From 7703e3a15eba29bfe265b9b3df3930320711996c Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 21 Dec 2023 11:15:58 -0700
Subject: [PATCH] Fixes for sdxl ip adapter training. Bug fixes

---
 extensions_built_in/sd_trainer/SDTrainer.py |  7 +--
 toolkit/config_modules.py                   |  2 +
 toolkit/dataloader_mixins.py                | 53 ++++++++++++++++++++-
 toolkit/image_utils.py                      |  2 +-
 toolkit/ip_adapter.py                       | 10 ++--
 5 files changed, 63 insertions(+), 11 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index b91c1c5d..dd8fdad3 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -173,12 +173,9 @@ class SDTrainer(BaseSDTrainProcess):
             if torch.isnan(prior_loss).any():
                 raise ValueError("Prior loss is nan")
 
-            prior_loss = prior_loss.mean([1, 2, 3])
-
-        loss = loss.mean([1, 2, 3])
-
-        if prior_loss is not None:
+            # prior_loss = prior_loss.mean([1, 2, 3])
             loss = loss + prior_loss
+        loss = loss.mean([1, 2, 3])
 
         if self.train_config.learnable_snr_gos:
             # add snr_gamma
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index cd3c67db..41dc9f6a 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -423,6 +423,8 @@ class DatasetConfig:
         # ip adapter / reference dataset
         self.clip_image_path: str = kwargs.get('clip_image_path', None)
         # depth maps, etc
+        self.clip_image_augmentations: List[dict] = kwargs.get('clip_image_augmentations', None)
+        self.clip_image_shuffle_augmentations: bool = kwargs.get('clip_image_shuffle_augmentations', False)
 
 
 def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
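[Note] A minimal sketch of how the two new dataset keys could be exercised, assuming DatasetConfig is built from kwargs as above and that the Augments wrapper (consumed in dataloader_mixins.py below) accepts method_name and params. The specific albumentations methods and values here are illustrative only, not part of the patch:

    from toolkit.config_modules import DatasetConfig

    # Hypothetical example values: method_name must name a class on the
    # albumentations module (the loader asserts hasattr(A, method_name));
    # params are forwarded to that class as keyword arguments.
    dataset_config = DatasetConfig(
        clip_image_path='/path/to/clip_images',
        clip_image_augmentations=[
            {'method_name': 'GaussianBlur', 'params': {'blur_limit': (3, 7), 'p': 0.5}},
            {'method_name': 'ColorJitter', 'params': {'brightness': 0.2, 'contrast': 0.2, 'p': 0.5}},
        ],
        clip_image_shuffle_augmentations=True,
    )

With clip_image_shuffle_augmentations enabled, augment_clip_image() rebuilds the A.Compose pipeline on every call, so the augmentation order is reshuffled per image rather than fixed at dataset init.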
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index f513f46d..01641be9 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -534,6 +534,8 @@ class ClipImageFileItemDTOMixin:
         self.has_clip_image = False
         self.clip_image_path: Union[str, None] = None
         self.clip_image_tensor: Union[torch.Tensor, None] = None
+        self.has_clip_augmentations = False
+        self.clip_image_aug_transform: Union[None, A.Compose] = None
         dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
         if dataset_config.clip_image_path is not None:
             # find the control image path
@@ -548,6 +550,51 @@ class ClipImageFileItemDTOMixin:
                     self.has_clip_image = True
                     break
 
+        self.build_clip_image_augmentation_transform()
+
+    def build_clip_image_augmentation_transform(self: 'FileItemDTO'):
+        if self.dataset_config.clip_image_augmentations is not None and len(self.dataset_config.clip_image_augmentations) > 0:
+            self.has_clip_augmentations = True
+            augmentations = [Augments(**aug) for aug in self.dataset_config.clip_image_augmentations]
+
+            if self.dataset_config.clip_image_shuffle_augmentations:
+                random.shuffle(augmentations)
+
+            augmentation_list = []
+            for aug in augmentations:
+                # make sure method name is valid
+                assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
+                # get the method
+                method = getattr(A, aug.method_name)
+                # add the method to the list
+                augmentation_list.append(method(**aug.params))
+
+            self.clip_image_aug_transform = A.Compose(augmentation_list)
+
+    def augment_clip_image(self: 'FileItemDTO', img: Image, transform: Union[None, transforms.Compose]):
+        if self.dataset_config.clip_image_shuffle_augmentations:
+            self.build_clip_image_augmentation_transform()
+
+        # save the original tensor
+        self.unaugmented_tensor = transforms.ToTensor()(img) if transform is None else transform(img)
+
+        open_cv_image = np.array(img)
+        # Convert RGB to BGR
+        open_cv_image = open_cv_image[:, :, ::-1].copy()
+
+        # apply augmentations
+        augmented = self.clip_image_aug_transform(image=open_cv_image)["image"]
+
+        # convert back to RGB
+        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)
+
+        # convert to PIL image
+        augmented = Image.fromarray(augmented)
+
+        augmented_tensor = transforms.ToTensor()(augmented) if transform is None else transform(augmented)
+
+        return augmented_tensor
+
     def load_clip_image(self: 'FileItemDTO'):
         img = Image.open(self.clip_image_path).convert('RGB')
         try:
@@ -558,8 +605,10 @@ class ClipImageFileItemDTOMixin:
             # we just scale them to 512x512:
             img = img.resize((512, 512), Image.BICUBIC)
 
-
-        self.clip_image_tensor = transforms.ToTensor()(img)
+        if self.has_clip_augmentations:
+            self.clip_image_tensor = self.augment_clip_image(img, transform=None)
+        else:
+            self.clip_image_tensor = transforms.ToTensor()(img)
 
     def cleanup_clip_image(self: 'FileItemDTO'):
         self.clip_image_tensor = None
diff --git a/toolkit/image_utils.py b/toolkit/image_utils.py
index b4dbc38d..21d8df79 100644
--- a/toolkit/image_utils.py
+++ b/toolkit/image_utils.py
@@ -432,7 +432,7 @@ def show_img(img, name='AI Toolkit'):
         img = np.clip(img, 0, 255).astype(np.uint8)
 
         cv2.imshow(name, img[:, :, ::-1])
-        k = cv2.waitKey(10) & 0xFF
+        k = cv2.waitKey(1) & 0xFF
         if k == 27:  # Esc key to stop
             print('\nESC pressed, stopping')
             raise KeyboardInterrupt
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index d4967c6a..a02e1ff2 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -170,12 +170,16 @@ class IPAdapter(torch.nn.Module):
                 clip_extra_context_tokens=self.config.num_tokens,  # usually 4
             )
         elif adapter_config.type == 'ip+':
+            heads = 12 if not sd.is_xl else 20
+            dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
+            # size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
+            # size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
             # ip-adapter-plus
             image_proj_model = Resampler(
-                dim=sd.unet.config['cross_attention_dim'],
+                dim=dim,
                 depth=4,
                 dim_head=64,
-                heads=12,
+                heads=heads,
                 num_queries=self.config.num_tokens,  # usually 16
                 embedding_dim=self.image_encoder.config.hidden_size,
                 output_dim=sd.unet.config['cross_attention_dim'],
@@ -266,7 +270,7 @@ class IPAdapter(torch.nn.Module):
     def set_scale(self, scale):
         self.current_scale = scale
         for attn_processor in self.sd_ref().unet.attn_processors.values():
-            if isinstance(attn_processor, IPAttnProcessor):
+            if isinstance(attn_processor, CustomIPAttentionProcessor):
                 attn_processor.scale = scale
 
     @torch.no_grad()
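[Note] On the ip+ sizing fix: the commented size-mismatch errors show that the Resampler allocates its internal latents parameter as [1, num_queries, dim]. SDXL's UNet reports a cross_attention_dim of 2048, but the published ip-adapter-plus SDXL weights hold latents of shape [1, 16, 1280] (and use 20 heads), so passing the UNet dim straight through fails to load. Below is a minimal sanity check of the patched sizing; resampler_dims is a hypothetical helper that just restates the numbers from the hunk above:

    import torch

    def resampler_dims(is_xl: bool, cross_attention_dim: int, num_tokens: int = 16):
        # Mirrors the patch: SDXL ip+ checkpoints expect dim=1280 and 20 heads,
        # while SD 1.5 keeps the UNet's cross_attention_dim and 12 heads.
        heads = 12 if not is_xl else 20
        dim = cross_attention_dim if not is_xl else 1280
        # the Resampler's internal latents tensor, shaped [1, num_queries, dim]
        latents = torch.randn(1, num_tokens, dim)
        return heads, dim, latents

    # SDXL: cross_attention_dim is 2048, but the checkpoint stores [1, 16, 1280]
    heads, dim, latents = resampler_dims(is_xl=True, cross_attention_dim=2048)
    assert (heads, dim, tuple(latents.shape)) == (20, 1280, (1, 16, 1280))

Note that output_dim in the hunk still uses the UNet's cross_attention_dim, so the resampled image tokens remain compatible with the attention layers they feed into.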