diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 1d9c99a6..b91c1c5d 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -567,6 +567,10 @@ class SDTrainer(BaseSDTrainProcess): network_weight_list = network_weight_list + network_weight_list has_adapter_img = batch.control_tensor is not None + has_clip_image = batch.clip_image_tensor is not None + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img: + raise ValueError("IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ") match_adapter_assist = False @@ -604,10 +608,14 @@ class SDTrainer(BaseSDTrainProcess): adapter_images = adapter_images[:, :in_channels, :, :] else: raise NotImplementedError("Adapter images now must be loaded with dataloader") - # not 100% sure what this does. But they do it here - # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170 - # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) - # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5) + + clip_images = None + if has_clip_image: + with self.timer('get_clip_images'): + # todo move this to data loader + if batch.clip_image_tensor is not None: + clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach() + mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) if batch.mask_tensor is not None: @@ -697,6 +705,10 @@ class SDTrainer(BaseSDTrainProcess): adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0) else: adapter_images_list = [None for _ in range(batch_size)] + if clip_images is not None: + clip_images_list = torch.chunk(clip_images, batch_size, dim=0) + else: + clip_images_list = [None for _ in 
range(batch_size)] mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0) if prompts_2 is None: prompt_2_list = [None for _ in range(batch_size)] @@ -710,19 +722,21 @@ class SDTrainer(BaseSDTrainProcess): conditioned_prompts_list = [prompts_1] imgs_list = [imgs] adapter_images_list = [adapter_images] + clip_images_list = [clip_images] mask_multiplier_list = [mask_multiplier] if prompts_2 is None: prompt_2_list = [None] else: prompt_2_list = [prompts_2] - for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip( + for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip( noisy_latents_list, noise_list, timesteps_list, conditioned_prompts_list, imgs_list, adapter_images_list, + clip_images_list, mask_multiplier_list, prompt_2_list ): @@ -766,7 +780,7 @@ class SDTrainer(BaseSDTrainProcess): if has_adapter_img and ( (self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter): with torch.set_grad_enabled(self.adapter is not None): - adapter = self.adapter if self.adapter else self.assistant_adapter + adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter adapter_multiplier = get_adapter_multiplier() with self.timer('encode_adapter'): down_block_additional_residuals = adapter(adapter_images) @@ -787,20 +801,20 @@ class SDTrainer(BaseSDTrainProcess): if self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter_embeds'): - if has_adapter_img: + if has_clip_image: conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( - adapter_images.detach().to(self.device_torch, dtype=dtype), + clip_images.detach().to(self.device_torch, dtype=dtype), is_training=True ) elif is_reg: # we will zero it out in the img embedder - adapter_img = torch.zeros( + clip_images = torch.zeros( (noisy_latents.shape[0], 3, 512, 512), device=self.device_torch, 
dtype=dtype ).detach() # drop will zero it out conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( - adapter_img, + clip_images, drop=True, is_training=True ) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 19e75a91..cd3c67db 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -421,6 +421,9 @@ class DatasetConfig: self.caption_type = self.caption_ext self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted') + # ip adapter / reference dataset + self.clip_image_path: str = kwargs.get('clip_image_path', None) # reference images for the IP adapter's CLIP vision encoder, not control/depth maps + def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]: """ diff --git a/toolkit/data_transfer_object/data_loader.py b/toolkit/data_transfer_object/data_loader.py index f71add54..df4c2ad7 100644 --- a/toolkit/data_transfer_object/data_loader.py +++ b/toolkit/data_transfer_object/data_loader.py @@ -8,7 +8,7 @@ from PIL.ImageOps import exif_transpose from toolkit import image_utils from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \ - UnconditionalFileItemDTOMixin + UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin if TYPE_CHECKING: from toolkit.config_modules import DatasetConfig @@ -28,6 +28,7 @@ class FileItemDTO( CaptionProcessingDTOMixin, ImageProcessingDTOMixin, ControlFileItemDTOMixin, + ClipImageFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, UnconditionalFileItemDTOMixin, @@ -71,6 +72,7 @@ class FileItemDTO( self.tensor = None self.cleanup_latent() self.cleanup_control() + self.cleanup_clip_image() self.cleanup_mask() self.cleanup_unconditional() @@ -83,6 +85,7 @@ class DataLoaderBatchDTO: self.tensor: Union[torch.Tensor, None] = None self.latents: Union[torch.Tensor, None] = None self.control_tensor: Union[torch.Tensor,
None] = None + self.clip_image_tensor: Union[torch.Tensor, None] = None self.mask_tensor: Union[torch.Tensor, None] = None self.unaugmented_tensor: Union[torch.Tensor, None] = None self.unconditional_tensor: Union[torch.Tensor, None] = None @@ -113,6 +116,21 @@ class DataLoaderBatchDTO: control_tensors.append(x.control_tensor) self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors]) + if any([x.clip_image_tensor is not None for x in self.file_items]): + # find one to use as a base + base_clip_image_tensor = None + for x in self.file_items: + if x.clip_image_tensor is not None: + base_clip_image_tensor = x.clip_image_tensor + break + clip_image_tensors = [] + for x in self.file_items: + if x.clip_image_tensor is None: + clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor)) + else: + clip_image_tensors.append(x.clip_image_tensor) + self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors]) + if any([x.mask_tensor is not None for x in self.file_items]): # find one to use as a base base_mask_tensor = None diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index dd8dfa19..f513f46d 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -350,6 +350,8 @@ class ImageProcessingDTOMixin: self.get_latent() if self.has_control_image: self.load_control_image() + if self.has_clip_image: + self.load_clip_image() if self.has_mask_image: self.load_mask_image() if self.has_unconditional: @@ -443,6 +445,8 @@ class ImageProcessingDTOMixin: if not only_load_latents: if self.has_control_image: self.load_control_image() + if self.has_clip_image: + self.load_clip_image() if self.has_mask_image: self.load_mask_image() if self.has_unconditional: @@ -523,6 +527,46 @@ class ControlFileItemDTOMixin: self.control_tensor = None +class ClipImageFileItemDTOMixin: + def __init__(self: 'FileItemDTO', *args, **kwargs): + if hasattr(super(), '__init__'): + super().__init__(*args, **kwargs) + 
self.has_clip_image = False + self.clip_image_path: Union[str, None] = None + self.clip_image_tensor: Union[torch.Tensor, None] = None + dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None) + if dataset_config.clip_image_path is not None: + # find the matching clip/reference image path + clip_image_path = dataset_config.clip_image_path + # we are using clip reference images + img_path = kwargs.get('path', None) + img_ext_list = ['.jpg', '.jpeg', '.png', '.webp'] + file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0] + for ext in img_ext_list: + if os.path.exists(os.path.join(clip_image_path, file_name_no_ext + ext)): + self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext) + self.has_clip_image = True + break + + def load_clip_image(self: 'FileItemDTO'): + img = Image.open(self.clip_image_path).convert('RGB') + try: + img = exif_transpose(img) + except Exception as e: + print(f"Error: {e}") + print(f"Error loading image: {self.clip_image_path}") + + # we just scale them to 512x512: + img = img.resize((512, 512), Image.BICUBIC) + + self.clip_image_tensor = transforms.ToTensor()(img) + + def cleanup_clip_image(self: 'FileItemDTO'): + self.clip_image_tensor = None + + + + class AugmentationFileItemDTOMixin: def __init__(self: 'FileItemDTO', *args, **kwargs): if hasattr(super(), '__init__'): diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index b4eef5b7..d4967c6a 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -153,7 +153,10 @@ class IPAdapter(torch.nn.Module): super().__init__() self.config = adapter_config self.sd_ref: weakref.ref = weakref.ref(sd) - self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) + try: + self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path) + except EnvironmentError: + self.clip_image_processor = CLIPImageProcessor() self.device = self.sd_ref().unet.device self.image_encoder =
CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path, ignore_mismatched_sizes=True)