diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index a3ff96cc..1d9c99a6 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -787,25 +787,33 @@ class SDTrainer(BaseSDTrainProcess):
 
         if self.adapter and isinstance(self.adapter, IPAdapter):
             with self.timer('encode_adapter_embeds'):
-                with torch.no_grad():
-                    if has_adapter_img:
-                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
-                            adapter_images.detach().to(self.device_torch, dtype=dtype))
-                    elif is_reg:
-                        # we will zero it out in the img embedder
-                        adapter_img = torch.zeros(
-                            (noisy_latents.shape[0], 3, 512, 512),
-                            device=self.device_torch, dtype=dtype
-                        )
-                        # drop will zero it out
-                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
-                            adapter_img, drop=True
-                        )
-                    else:
-                        raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
+                if has_adapter_img:
+                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
+                        adapter_images.detach().to(self.device_torch, dtype=dtype),
+                        is_training=True
+                    )
+                elif is_reg:
+                    # we will zero it out in the img embedder
+                    adapter_img = torch.zeros(
+                        (noisy_latents.shape[0], 3, 512, 512),
+                        device=self.device_torch, dtype=dtype
+                    ).detach()
+                    # drop will zero it out
+                    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
+                        adapter_img,
+                        drop=True,
+                        is_training=True
+                    )
+                else:
+                    raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
+
+                if not self.adapter_config.train_image_encoder:
+                    # we are not training the image encoder, so we need to detach the embeds
+                    conditional_clip_embeds = conditional_clip_embeds.detach()
+
 
             with self.timer('encode_adapter'):
-                conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach())
+                conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
 
         prior_pred = None
         if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg):
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 55b2b9c4..25fbf12a 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -870,6 +870,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 sd=self.sd,
                 adapter_config=self.adapter_config,
             )
+            if self.train_config.gradient_checkpointing:
+                self.adapter.enable_gradient_checkpointing()
             self.adapter.to(self.device_torch, dtype=dtype)
             if latest_save_path is not None:
                 # load adapter from path
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index e57bd5ec..19e75a91 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -142,6 +142,17 @@ class AdapterConfig:
 
         self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
         self.name_or_path = kwargs.get('name_or_path', None)
+        num_tokens = kwargs.get('num_tokens', None)
+        if num_tokens is None and self.type.startswith('ip'):
+            if self.type == 'ip+':
+                num_tokens = 16
+            elif self.type == 'ip':
+                num_tokens = 4
+
+        self.num_tokens: int = num_tokens
+        self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
+
+
 
 class EmbeddingConfig:
     def __init__(self, **kwargs):
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 68fa3bb4..b4eef5b7 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -12,7 +12,8 @@ from toolkit.train_tools import get_torch_dtype
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
 from collections import OrderedDict
-from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, AttnProcessor2_0
+from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
+    AttnProcessor2_0
 from ipadapter.ip_adapter.ip_adapter import ImageProjModel
 from ipadapter.ip_adapter.resampler import Resampler
 from toolkit.config_modules import AdapterConfig
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
 from transformers import (
     CLIPImageProcessor,
     CLIPVisionModelWithProjection,
+    CLIPVisionModel
 )
 
 import torch.nn.functional as F
@@ -151,9 +153,10 @@ class IPAdapter(torch.nn.Module):
         super().__init__()
         self.config = adapter_config
         self.sd_ref: weakref.ref = weakref.ref(sd)
-        self.clip_image_processor = CLIPImageProcessor()
+        self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
         self.device = self.sd_ref().unet.device
-        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)
+        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path,
+                                                                           ignore_mismatched_sizes=True)
         self.current_scale = 1.0
         self.is_active = True
         if adapter_config.type == 'ip':
@@ -161,17 +164,16 @@ class IPAdapter(torch.nn.Module):
             image_proj_model = ImageProjModel(
                 cross_attention_dim=sd.unet.config['cross_attention_dim'],
                 clip_embeddings_dim=self.image_encoder.config.projection_dim,
-                clip_extra_context_tokens=4,
+                clip_extra_context_tokens=self.config.num_tokens,  # usually 4
             )
         elif adapter_config.type == 'ip+':
             # ip-adapter-plus
-            num_tokens = 16
             image_proj_model = Resampler(
                 dim=sd.unet.config['cross_attention_dim'],
                 depth=4,
                 dim_head=64,
                 heads=12,
-                num_queries=num_tokens,
+                num_queries=self.config.num_tokens,  # usually 16
                 embedding_dim=self.image_encoder.config.hidden_size,
                 output_dim=sd.unet.config['cross_attention_dim'],
                 ff_mult=4
@@ -203,20 +205,12 @@ class IPAdapter(torch.nn.Module):
                     "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                     "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
                 }
-                if adapter_config.type == 'ip':
-                    # ip-adapter
-                    num_tokens = 4
-                elif adapter_config.type == 'ip+':
-                    # ip-adapter-plus
-                    num_tokens = 16
-                else:
-                    raise ValueError(f"unknown adapter type: {adapter_config.type}")
 
                 attn_procs[name] = CustomIPAttentionProcessor(
                     hidden_size=hidden_size,
                     cross_attention_dim=cross_attention_dim,
                     scale=1.0,
-                    num_tokens=num_tokens,
+                    num_tokens=self.config.num_tokens,
                     adapter=self
                 )
                 attn_procs[name].load_state_dict(weights)
@@ -249,6 +243,8 @@ class IPAdapter(torch.nn.Module):
         self.image_proj_model.load_state_dict(state_dict["image_proj"])
         ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
         ip_layers.load_state_dict(state_dict["ip_adapter"])
+        if self.config.train_image_encoder and 'image_encoder' in state_dict:
+            self.image_encoder.load_state_dict(state_dict["image_encoder"])
 
     # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
     #     self.load_ip_adapter(state_dict)
@@ -257,6 +253,8 @@ class IPAdapter(torch.nn.Module):
         state_dict = OrderedDict()
         state_dict["image_proj"] = self.image_proj_model.state_dict()
         state_dict["ip_adapter"] = self.adapter_modules.state_dict()
+        if self.config.train_image_encoder:
+            state_dict["image_encoder"] = self.image_encoder.state_dict()
         return state_dict
 
     def get_scale(self):
@@ -281,37 +279,43 @@ class IPAdapter(torch.nn.Module):
         clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
         return clip_image_embeds
 
-    @torch.no_grad()
-    def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False) -> torch.Tensor:
-        # tensors should be 0-1
-        # todo: add support for sdxl
-        if tensors_0_1.ndim == 3:
-            tensors_0_1 = tensors_0_1.unsqueeze(0)
-        # training tensors are 0 - 1
-        tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
-        # if images are out of this range throw error
-        if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
-            raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
-                tensors_0_1.min(), tensors_0_1.max()
-            ))
+    def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False,
+                                           is_training=False) -> torch.Tensor:
+        with torch.no_grad():
+            # tensors should be 0-1
+            # todo: add support for sdxl
+            if tensors_0_1.ndim == 3:
+                tensors_0_1 = tensors_0_1.unsqueeze(0)
+            # training tensors are 0 - 1
+            tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+            # if images are out of this range throw error
+            if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+                raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+                    tensors_0_1.min(), tensors_0_1.max()
+                ))
 
-        clip_image = self.clip_image_processor(
-            images=tensors_0_1,
-            return_tensors="pt",
-            do_resize=True,
-            do_rescale=False,
-        ).pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
-        if drop:
-            clip_image = clip_image * 0
-        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+            clip_image = self.clip_image_processor(
+                images=tensors_0_1,
+                return_tensors="pt",
+                do_resize=True,
+                do_rescale=False,
+            ).pixel_values
+            clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
+            if drop:
+                clip_image = clip_image * 0
+        with torch.set_grad_enabled(is_training):
+            if is_training:
+                self.image_encoder.train()
+            else:
+                self.image_encoder.eval()
+            clip_output = self.image_encoder(clip_image, output_hidden_states=True)
+            clip_image_embeds = clip_output.hidden_states[-2]
         return clip_image_embeds
 
     # use drop for prompt dropout, or negatives
     def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
-        clip_image_embeds = clip_image_embeds.detach()
         clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
-        image_prompt_embeds = self.image_proj_model(clip_image_embeds.detach())
+        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
         embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
         return embeddings
 
@@ -319,7 +323,14 @@ class IPAdapter(torch.nn.Module):
         for attn_processor in self.adapter_modules:
             yield from attn_processor.parameters(recurse)
         yield from self.image_proj_model.parameters(recurse)
+        if self.config.train_image_encoder:
+            yield from self.image_encoder.parameters(recurse)
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
         self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+        if self.config.train_image_encoder and 'image_encoder' in state_dict:
+            self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
+
+    def enable_gradient_checkpointing(self):
+        self.image_encoder.gradient_checkpointing = True