diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 452e028c..984bd7dc 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -549,11 +549,13 @@ class SDTrainer(BaseSDTrainProcess):
         self.timer.stop('preprocess_batch')

+        is_reg = False
         with torch.no_grad():
             loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
             for idx, file_item in enumerate(batch.file_items):
                 if file_item.is_reg:
                     loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
+                    is_reg = True

         adapter_images = None
         sigmas = None
@@ -764,11 +766,27 @@ class SDTrainer(BaseSDTrainProcess):
                     batch=batch,
                 )

-        if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
-            with self.timer('encode_adapter'):
+        if self.adapter and isinstance(self.adapter, IPAdapter):
+            with self.timer('encode_adapter_embeds'):
                 with torch.no_grad():
-                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
-                    conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
+                    if has_adapter_img:
+                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
+                            adapter_images.detach().to(self.device_torch, dtype=dtype))
+                    elif is_reg:
+                        # we will zero it out in the img embedder
+                        adapter_img = torch.zeros(
+                            (noisy_latents.shape[0], 3, 512, 512),
+                            device=self.device_torch, dtype=dtype
+                        )
+                        # drop will zero it out
+                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
+                            adapter_img, drop=True
+                        )
+                    else:
+                        raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
+
+            with self.timer('encode_adapter'):
+                conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach())

         self.before_unet_predict()
         # do a prior pred if we have an unconditional image, we will swap out the giadance later
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 8bf45ca3..55b2b9c4 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -880,13 +880,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     self.device,
                     dtype=dtype
                 )
+                self.adapter.load_state_dict(loaded_state_dict)
             else:
+                # ip adapter
                 loaded_state_dict = load_ip_adapter_model(
                     latest_save_path,
                     self.device,
                     dtype=dtype
                 )
-            self.adapter.load_state_dict(loaded_state_dict)
+                self.adapter.load_state_dict(loaded_state_dict)
             if self.adapter_config.train:
                 self.load_training_state_from_metadata(latest_save_path)
             # set trainable params
diff --git a/repositories/ipadapter b/repositories/ipadapter
index d8ab37c4..f71c943b 160000
--- a/repositories/ipadapter
+++ b/repositories/ipadapter
@@ -1 +1 @@
-Subproject commit d8ab37c421c1ab95d15abe094e8266a6d01e26ef
+Subproject commit f71c943b7e1d3ffccae8e4f04b9adebac037e19f
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 454b923d..e57bd5ec 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -375,6 +375,8 @@ class DatasetConfig:
         self.flip_y: bool = kwargs.get('flip_y', False)
         self.augments: List[str] = kwargs.get('augments', [])
         self.control_path: str = kwargs.get('control_path', None)  # depth maps, etc
+        # instead of cropping ot match image, it will serve the full size control image (clip images ie for ip adapters)
+        self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
         self.alpha_mask: bool = kwargs.get('alpha_mask', False)  # if true, will use alpha channel as mask
         self.mask_path: str = kwargs.get('mask_path', None)  # focus mask (black and white. White has higher loss than black)
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 4842556a..f3aa6ff7 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -457,9 +457,11 @@ class ControlFileItemDTOMixin:
         self.control_path: Union[str, None] = None
         self.control_tensor: Union[torch.Tensor, None] = None
         dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
+        self.full_size_control_images = False
         if dataset_config.control_path is not None:
             # find the control image path
             control_path = dataset_config.control_path
+            self.full_size_control_images = dataset_config.full_size_control_images
             # we are using control images
             img_path = kwargs.get('path', None)
             img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
@@ -477,36 +479,38 @@ class ControlFileItemDTOMixin:
         except Exception as e:
             print(f"Error: {e}")
             print(f"Error loading image: {self.control_path}")
-        w, h = img.size
-        if w > h and self.scale_to_width < self.scale_to_height:
-            # throw error, they should match
-            raise ValueError(
-                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
-        elif h > w and self.scale_to_height < self.scale_to_width:
-            # throw error, they should match
-            raise ValueError(
-                f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
-        if self.flip_x:
-            # do a flip
-            img.transpose(Image.FLIP_LEFT_RIGHT)
-        if self.flip_y:
-            # do a flip
-            img.transpose(Image.FLIP_TOP_BOTTOM)
+        if not self.full_size_control_images:
+            w, h = img.size
+            if w > h and self.scale_to_width < self.scale_to_height:
+                # throw error, they should match
+                raise ValueError(
+                    f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
+            elif h > w and self.scale_to_height < self.scale_to_width:
+                # throw error, they should match
+                raise ValueError(
+                    f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

-        if self.dataset_config.buckets:
-            # scale and crop based on file item
-            img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
-            # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
-            # crop
-            img = img.crop((
-                self.crop_x,
-                self.crop_y,
-                self.crop_x + self.crop_width,
-                self.crop_y + self.crop_height
-            ))
-        else:
-            raise Exception("Control images not supported for non-bucket datasets")
+            if self.flip_x:
+                # do a flip
+                img.transpose(Image.FLIP_LEFT_RIGHT)
+            if self.flip_y:
+                # do a flip
+                img.transpose(Image.FLIP_TOP_BOTTOM)
+
+            if self.dataset_config.buckets:
+                # scale and crop based on file item
+                img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
+                # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
+                # crop
+                img = img.crop((
+                    self.crop_x,
+                    self.crop_y,
+                    self.crop_x + self.crop_width,
+                    self.crop_y + self.crop_height
+                ))
+            else:
+                raise Exception("Control images not supported for non-bucket datasets")

         self.control_tensor = transforms.ToTensor()(img)
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index f252af28..d7d9fa5a 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -6,6 +6,7 @@ from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

 from toolkit.paths import REPOS_ROOT
+from toolkit.saving import load_ip_adapter_model
 from toolkit.train_tools import get_torch_dtype

 sys.path.append(REPOS_ROOT)
@@ -21,6 +22,16 @@ import weakref
 if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion

+from transformers import (
+    CLIPImageProcessor,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers.models.attention_processor import (
+    IPAdapterAttnProcessor,
+    IPAdapterAttnProcessor2_0,
+)
+
 # loosely based on
 # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
 class IPAdapter(torch.nn.Module):
@@ -89,6 +100,16 @@ class IPAdapter(torch.nn.Module):
         self.unet_ref: weakref.ref = weakref.ref(sd.unet)
         self.image_proj_model = image_proj_model
         self.adapter_modules = adapter_modules
+        # load the weights if we have some
+        if self.config.name_or_path:
+            loaded_state_dict = load_ip_adapter_model(
+                self.config.name_or_path,
+                device='cpu',
+                dtype=sd.torch_dtype
+            )
+            self.load_state_dict(loaded_state_dict)
+
+        self.set_scale(1.0)

     def to(self, *args, **kwargs):
         super().to(*args, **kwargs)
@@ -102,6 +123,9 @@ class IPAdapter(torch.nn.Module):
         ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
         ip_layers.load_state_dict(state_dict["ip_adapter"])

+    # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
+    #     self.load_ip_adapter(state_dict)
+
     def state_dict(self) -> OrderedDict:
         state_dict = OrderedDict()
         state_dict["image_proj"] = self.image_proj_model.state_dict()
@@ -109,7 +133,7 @@ class IPAdapter(torch.nn.Module):
         return state_dict

     def set_scale(self, scale):
-        for attn_processor in self.pipe.unet.attn_processors.values():
+        for attn_processor in self.sd_ref().unet.attn_processors.values():
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale

@@ -131,9 +155,21 @@ class IPAdapter(torch.nn.Module):
         # todo: add support for sdxl
         if tensors_0_1.ndim == 3:
             tensors_0_1 = tensors_0_1.unsqueeze(0)
+        # training tensors are 0 - 1
         tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
-        clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        # if images are out of this range throw error
+        if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+            raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+                tensors_0_1.min(), tensors_0_1.max()
+            ))
+
+        clip_image = self.clip_image_processor(
+            images=tensors_0_1,
+            return_tensors="pt",
+            do_resize=True,
+            do_rescale=False,
+        ).pixel_values
+        clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
         if drop:
             clip_image = clip_image * 0
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
@@ -155,3 +191,4 @@ class IPAdapter(torch.nn.Module):
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
         self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 6bfe3e5e..18d78a52 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -474,9 +474,7 @@ class StableDiffusion:
                     extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
                 if isinstance(self.adapter, IPAdapter):
                     transform = transforms.Compose([
-                        transforms.Resize(gen_config.width,
-                                          interpolation=transforms.InterpolationMode.BILINEAR),
-                        transforms.PILToTensor(),
+                        transforms.ToTensor(),
                     ])
                     validation_image = transform(validation_image)
@@ -500,6 +498,7 @@ class StableDiffusion:

             if self.adapter is not None and isinstance(self.adapter, IPAdapter) and gen_config.adapter_image_path is not None:
+                # apply the image projection
                 conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                 unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,