From 0f8daa5612807091b4ce90af590f9aeed02f5276 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 24 Dec 2023 08:32:39 -0700 Subject: [PATCH] Bug fixes, work on making IP adapters more customizable. --- extensions_built_in/sd_trainer/SDTrainer.py | 2 +- toolkit/config_modules.py | 14 ++- toolkit/data_loader.py | 68 +++++++++++-- toolkit/data_transfer_object/data_loader.py | 2 + toolkit/dataloader_mixins.py | 102 +++++++++++++++++--- toolkit/ip_adapter.py | 82 +++++++++++++--- toolkit/prompt_utils.py | 9 +- 7 files changed, 243 insertions(+), 36 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index dd8fdad3..dfd51e72 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -827,7 +827,7 @@ class SDTrainer(BaseSDTrainProcess): conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds) prior_pred = None - if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg): + if ((has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction) and not is_reg: with self.timer('prior predict'): prior_pred = self.get_prior_prediction( noisy_latents=noisy_latents, diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 41dc9f6a..96074c0d 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -151,6 +151,16 @@ class AdapterConfig: self.num_tokens: int = num_tokens self.train_image_encoder: bool = kwargs.get('train_image_encoder', False) + self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid + + + +class ClipTokenMakerConfig: + def __init__(self, **kwargs): + self.image_encoder_path: str = kwargs.get('image_encoder_path', None) + self.num_tokens: int = kwargs.get('num_tokens', 8) + + @@ -393,7 +403,7 @@ class DatasetConfig: None) # focus mask (black and white. 
White has higher loss than black) self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask - self.mask_min_value: float = kwargs.get('mask_min_value', 0.01) # min value for . 0 - 1 + self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for . 0 - 1 self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes self.num_repeats: int = kwargs.get('num_repeats', 1) # number of times to repeat dataset @@ -402,6 +412,8 @@ class DatasetConfig: # cache latents to disk will store them on disk. If both are true, it will save to disk, but keep in memory self.cache_latents_to_disk: bool = kwargs.get('cache_latents_to_disk', False) + self.standardize_images: bool = kwargs.get('standardize_images', False) + # https://albumentations.ai/docs/api_reference/augmentations/transforms # augmentations are returned as a separate image and cannot currently be cached self.augmentations: List[dict] = kwargs.get('augmentations', None) diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index 4d62b72d..0eaaee4c 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -8,6 +8,7 @@ from typing import List, TYPE_CHECKING import cv2 import numpy as np +import torch from PIL import Image from PIL.ImageOps import exif_transpose from torchvision import transforms @@ -24,6 +25,45 @@ if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion +class RescaleTransform: + """Transform to rescale images to the range [-1, 1].""" + + def __call__(self, image): + return image * 2 - 1 + + +class NormalizeSDXLTransform: + """ + Transforms the range from 0 to 1 to SDXL mean and std per channel based on avgs over thousands of images + + Mean: tensor([ 0.0002, -0.1034, -0.1879]) + Standard Deviation: tensor([0.5436, 0.5116, 0.5033]) + """ + + def 
__call__(self, image): return transforms.Normalize( mean=[0.0002, -0.1034, -0.1879], std=[0.5436, 0.5116, 0.5033], )(image) + + +class NormalizeSD15Transform: + """ + Transforms the range from 0 to 1 to SD15 mean and std per channel based on avgs over thousands of images + + Mean: tensor([-0.1600, -0.2450, -0.3227]) + Standard Deviation: tensor([0.5319, 0.4997, 0.5139]) + + """ + + def __call__(self, image): + return transforms.Normalize( + mean=[-0.1600, -0.2450, -0.3227], + std=[0.5319, 0.4997, 0.5139], + )(image) + + + class ImageDataset(Dataset, CaptionMixin): def __init__(self, config): self.config = config @@ -63,7 +103,7 @@ class ImageDataset(Dataset, CaptionMixin): self.transform = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1] + RescaleTransform(), ]) def get_config(self, key, default=None, required=False): @@ -200,7 +240,7 @@ class PairedImageDataset(Dataset): self.transform = transforms.Compose([ transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1] + RescaleTransform(), ]) def get_all_prompts(self): @@ -368,6 +408,23 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset): # repeat the list file_list = file_list * self.dataset_config.num_repeats + if self.dataset_config.standardize_images: + if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd: + NormalizeMethod = NormalizeSDXLTransform + else: + NormalizeMethod = NormalizeSD15Transform + + self.transform = transforms.Compose([ + transforms.ToTensor(), + RescaleTransform(), + NormalizeMethod(), + ]) + else: + self.transform = transforms.Compose([ + transforms.ToTensor(), + RescaleTransform(), + ]) + # this might take a while print(f" - Preprocessing image dimensions") bad_count = 0 @@ -375,7 +432,8 @@ try: file_item = FileItemDTO( path=file, - 
dataset_config=dataset_config, + dataloader_transforms=self.transform, ) self.file_list.append(file_item) except Exception as e: @@ -411,10 +469,6 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset): if self.dataset_config.flip_x or self.dataset_config.flip_y: print(f" - Found {len(self.file_list)} images after adding flips") - self.transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1] - ]) self.setup_epoch() diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index df4c2ad7..cc8699d3 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -48,6 +48,7 @@ class FileItemDTO( h, w = img.size self.width: int = w self.height: int = h + self.dataloader_transforms = kwargs.get('dataloader_transforms', None) super().__init__(*args, **kwargs) # self.caption_path: str = kwargs.get('caption_path', None) @@ -64,6 +65,7 @@ class FileItemDTO( self.flip_y: bool = kwargs.get('flip_x', False) self.augments: List[str] = self.dataset_config.augments + self.network_weight: float = self.dataset_config.network_weight self.is_reg = self.dataset_config.is_reg self.tensor: Union[torch.Tensor, None] = None diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 01641be9..71ca10b3 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -56,6 +56,30 @@ transforms_dict = { caption_ext_list = ['txt', 'json', 'caption'] +def standardize_images(images): + """ + Standardize the given batch of images using the specified mean and std. + Expects values of 0 - 1 + + Args: + images (torch.Tensor): A batch of images in the shape of (N, C, H, W), + where N is the number of images, C is the number of channels, + H is the height, and W is the width. + + Returns: + torch.Tensor: Standardized images. 
+ """ + mean = [0.48145466, 0.4578275, 0.40821073] + std = [0.26862954, 0.26130258, 0.27577711] + + # Define the normalization transform + normalize = transforms.Normalize(mean=mean, std=std) + + # Apply normalization to each image in the batch + standardized_images = torch.stack([normalize(img) for img in images]) + + return standardized_images + def clean_caption(caption): # remove any newlines caption = caption.replace('\n', ', ') @@ -520,8 +544,13 @@ class ControlFileItemDTOMixin: )) else: raise Exception("Control images not supported for non-bucket datasets") - - self.control_tensor = transforms.ToTensor()(img) + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + if self.aug_replay_spatial_transforms: + self.control_tensor = self.augment_spatial_control(img, transform=transform) + else: + self.control_tensor = transform(img) def cleanup_control(self: 'FileItemDTO'): self.control_tensor = None @@ -624,6 +653,8 @@ class AugmentationFileItemDTOMixin: self.unaugmented_tensor: Union[torch.Tensor, None] = None # self.augmentations: Union[None, List[Augments]] = None self.dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + self.aug_transform: Union[None, A.Compose] = None + self.aug_replay_spatial_transforms = None self.build_augmentation_transform() def build_augmentation_transform(self: 'FileItemDTO'): @@ -643,7 +674,8 @@ class AugmentationFileItemDTOMixin: # add the method to the list augmentation_list.append(method(**aug.params)) - self.aug_transform = A.Compose(augmentation_list) + # add additional targets so we can augment the control image + self.aug_transform = A.ReplayCompose(augmentation_list, additional_targets={'image2': 'image'}) def augment_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose], ): @@ -659,7 +691,17 @@ class AugmentationFileItemDTOMixin: open_cv_image = open_cv_image[:, :, ::-1].copy() # apply augmentations - augmented = self.aug_transform(image=open_cv_image)["image"] + 
transformed = self.aug_transform(image=open_cv_image) + augmented = transformed["image"] + + # save just the spatial transforms for controls and masks + augmented_params = transformed["replay"] + spatial_transforms = ['Rotate', 'Flip', 'HorizontalFlip', 'VerticalFlip', 'Resize', 'Crop', 'RandomCrop', + 'ElasticTransform', 'GridDistortion', 'OpticalDistortion'] + # only store the spatial transforms + augmented_params['transforms'] = [t for t in augmented_params['transforms'] if t['__class_fullname__'].split('.')[-1] in spatial_transforms] + + self.aug_replay_spatial_transforms = augmented_params # convert back to RGB tensor augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) @@ -671,6 +713,38 @@ class AugmentationFileItemDTOMixin: return augmented_tensor + # augment control images spatially consistent with transforms done to the main image + def augment_spatial_control(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose] ): + if self.aug_replay_spatial_transforms is None: + # no transforms + return transform(img) + + # save colorspace to convert back to + colorspace = img.mode + + # convert to rgb + img = img.convert('RGB') + + open_cv_image = np.array(img) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + # Replay transforms + transformed = A.ReplayCompose.replay(self.aug_replay_spatial_transforms, image=open_cv_image) + augmented = transformed["image"] + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + # convert back to original colorspace + augmented = augmented.convert(colorspace) + + augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented) + return augmented_tensor + def cleanup_control(self: 'FileItemDTO'): self.unaugmented_tensor = None @@ -760,7 +834,13 @@ class MaskFileItemDTOMixin: else: raise Exception("Mask images not supported for non-bucket 
datasets") - self.mask_tensor = transforms.ToTensor()(img) + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + if self.aug_replay_spatial_transforms: + self.mask_tensor = self.augment_spatial_control(img, transform=transform) + else: + self.mask_tensor = transform(img) self.mask_tensor = value_map(self.mask_tensor, 0, 1.0, self.mask_min_value, 1.0) # convert to grayscale @@ -776,12 +856,7 @@ class UnconditionalFileItemDTOMixin: self.unconditional_path: Union[str, None] = None self.unconditional_tensor: Union[torch.Tensor, None] = None self.unconditional_latent: Union[torch.Tensor, None] = None - self.unconditional_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) + self.unconditional_transforms = self.dataloader_transforms dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) if dataset_config.unconditional_path is not None: @@ -835,7 +910,10 @@ class UnconditionalFileItemDTOMixin: else: raise Exception("Unconditional images are not supported for non-bucket datasets") - self.unconditional_tensor = self.unconditional_transforms(img) + if self.aug_replay_spatial_transforms: + self.unconditional_tensor = self.augment_spatial_control(img, transform=self.unconditional_transforms) + else: + self.unconditional_tensor = self.unconditional_transforms(img) def cleanup_unconditional(self: 'FileItemDTO'): self.unconditional_tensor = None diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index a02e1ff2..62bd826d 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -26,8 +26,15 @@ if TYPE_CHECKING: from transformers import ( CLIPImageProcessor, CLIPVisionModelWithProjection, - CLIPVisionModel + CLIPVisionModel, + AutoImageProcessor, + ConvNextModel, + ConvNextForImageClassification, + ConvNextImageProcessor ) +from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification + +from transformers import ViTFeatureExtractor, ViTForImageClassification import 
torch.nn.functional as F @@ -153,13 +160,51 @@ class IPAdapter(torch.nn.Module): super().__init__() self.config = adapter_config self.sd_ref: weakref.ref = weakref.ref(sd) - try: - self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) - except EnvironmentError: - self.clip_image_processor = CLIPImageProcessor() self.device = self.sd_ref().unet.device - self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path, - ignore_mismatched_sizes=True) + if self.config.image_encoder_arch == 'clip': + try: + self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = CLIPImageProcessor() + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained( + adapter_config.image_encoder_path, + ignore_mismatched_sizes=True).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'vit': + try: + self.clip_image_processor = ViTFeatureExtractor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = ViTFeatureExtractor() + self.image_encoder = ViTForImageClassification.from_pretrained(adapter_config.image_encoder_path).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + elif self.config.image_encoder_arch == 'convnext': + try: + self.clip_image_processor = ConvNextImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.clip_image_processor = ConvNextImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.image_encoder = ConvNextForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) 
+ elif self.config.image_encoder_arch == 'vit-hybrid': + try: + self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + print(f"could not load image processor from {adapter_config.image_encoder_path}") + self.clip_image_processor = ViTHybridImageProcessor( + size=320, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + ) + self.image_encoder = ViTHybridForImageClassification.from_pretrained( + adapter_config.image_encoder_path, + use_safetensors=True, + ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + else: + raise ValueError(f"unknown image encoder arch: {adapter_config.image_encoder_arch}") self.current_scale = 1.0 self.is_active = True if adapter_config.type == 'ip': @@ -181,7 +226,7 @@ class IPAdapter(torch.nn.Module): dim_head=64, heads=heads, num_queries=self.config.num_tokens, # usually 16 - embedding_dim=self.image_encoder.config.hidden_size, + embedding_dim=self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1], output_dim=sd.unet.config['cross_attention_dim'], ff_mult=4 ) @@ -239,6 +284,10 @@ class IPAdapter(torch.nn.Module): self.set_scale(1.0) + if self.config.train_image_encoder: + self.image_encoder.train() + self.image_encoder.requires_grad_(True) + def to(self, *args, **kwargs): super().to(*args, **kwargs) self.image_encoder.to(*args, **kwargs) @@ -280,7 +329,7 @@ class IPAdapter(torch.nn.Module): if isinstance(pil_image, Image.Image): pil_image = [pil_image] clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16) + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) if drop: clip_image = clip_image * 0 clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] @@ 
-307,15 +356,18 @@ class IPAdapter(torch.nn.Module): do_resize=True, do_rescale=False, ).pixel_values - clip_image = clip_image.to(self.device, dtype=torch.float16).detach() + clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach() if drop: clip_image = clip_image * 0 with torch.set_grad_enabled(is_training): if is_training: self.image_encoder.train() + clip_output = self.image_encoder(clip_image.requires_grad_(True) + , output_hidden_states=True) else: self.image_encoder.eval() - clip_output = self.image_encoder(clip_image, output_hidden_states=True) + clip_output = self.image_encoder(clip_image, output_hidden_states=True) + clip_image_embeds = clip_output.hidden_states[-2] return clip_image_embeds @@ -332,8 +384,16 @@ class IPAdapter(torch.nn.Module): yield from self.image_proj_model.parameters(recurse) if self.config.train_image_encoder: yield from self.image_encoder.parameters(recurse) + # if self.config.train_image_encoder: + # yield from self.image_encoder.parameters(recurse) + # self.image_encoder.train() + # else: + # for attn_processor in self.adapter_modules: + # yield from attn_processor.parameters(recurse) + # yield from self.image_proj_model.parameters(recurse) def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + strict = False self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict) self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict) if self.config.train_image_encoder and 'image_encoder' in state_dict: diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index 6a5032f3..4a132967 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -39,10 +39,11 @@ class PromptEmbeds: return self def detach(self): - self.text_embeds = self.text_embeds.detach() - if self.pooled_embeds is not None: - self.pooled_embeds = self.pooled_embeds.detach() - return self + new_embeds = self.clone() + new_embeds.text_embeds = 
new_embeds.text_embeds.detach() + if new_embeds.pooled_embeds is not None: + new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach() + return new_embeds def clone(self): if self.pooled_embeds is not None: