diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 984bd7dc..a3ff96cc 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -94,9 +94,18 @@ class SDTrainer(BaseSDTrainProcess):
             noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
             noise_pred = noise_pred * (noise_norm / noise_pred_norm)

-        if self.train_config.inverted_mask_prior:
+        if self.train_config.inverted_mask_prior and prior_pred is not None:
             # we need to make the noise prediction be a masked blending of noise and prior_pred
-            prior_mask_multiplier = 1.0 - mask_multiplier
+            stretched_mask_multiplier = value_map(
+                mask_multiplier,
+                batch.file_items[0].dataset_config.mask_min_value,
+                1.0,
+                0.0,
+                1.0
+            )
+
+            prior_mask_multiplier = 1.0 - stretched_mask_multiplier
+            # target_mask_multiplier = mask_multiplier
             # mask_multiplier = 1.0

         target = noise
@@ -152,7 +161,8 @@ class SDTrainer(BaseSDTrainProcess):
         # multiply by our mask
         loss = loss * mask_multiplier

-        if self.train_config.inverted_mask_prior:
+        prior_loss = None
+        if self.train_config.inverted_mask_prior and prior_pred is not None:
             # to a loss to unmasked areas of the prior for unmasked regularization
             prior_loss = torch.nn.functional.mse_loss(
                 prior_pred.float(),
@@ -160,10 +170,16 @@ class SDTrainer(BaseSDTrainProcess):
                 reduction="none"
             )
             prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
-            loss = loss + prior_loss
+            if torch.isnan(prior_loss).any():
+                raise ValueError("Prior loss is nan")
+
+            prior_loss = prior_loss.mean([1, 2, 3])

         loss = loss.mean([1, 2, 3])

+        if prior_loss is not None:
+            loss = loss + prior_loss
+
         if self.train_config.learnable_snr_gos:
             # add snr_gamma
             loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
@@ -491,14 +507,25 @@ class SDTrainer(BaseSDTrainProcess):
             noise: torch.Tensor,
             **kwargs
     ):
+        was_unet_training = self.sd.unet.training
+        was_network_active = False
+        if self.network is not None:
+            was_network_active = self.network.is_active
+            self.network.is_active = False
+        is_ip_adapter = False
+        was_ip_adapter_active = False
+        if self.adapter is not None and isinstance(self.adapter, IPAdapter):
+            is_ip_adapter = True
+            was_ip_adapter_active = self.adapter.is_active
+            self.adapter.is_active = False
+
        # do a prediction here so we can match its output with network multiplier set to 0.0
         with torch.no_grad():
             dtype = get_torch_dtype(self.train_config.dtype)
             # dont use network on this
             # self.network.multiplier = 0.0
-            was_network_active = self.network.is_active
-            self.network.is_active = False
             self.sd.unet.eval()
+
             prior_pred = self.sd.predict_noise(
                 latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
                 conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
@@ -506,14 +533,19 @@ class SDTrainer(BaseSDTrainProcess):
                 guidance_scale=1.0,
                 **pred_kwargs  # adapter residuals in here
             )
-            self.sd.unet.train()
+            if was_unet_training:
+                self.sd.unet.train()
             prior_pred = prior_pred.detach()
             # remove the residuals as we wont use them on prediction when matching control
             if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
                 del pred_kwargs['down_intrablock_additional_residuals']
+
+            if is_ip_adapter:
+                self.adapter.is_active = was_ip_adapter_active
             # restore network
             # self.network.multiplier = network_weight_list
-            self.network.is_active = was_network_active
+            if self.network is not None:
+                self.network.is_active = was_network_active
         return prior_pred

     def before_unet_predict(self):
@@ -752,19 +784,6 @@ class SDTrainer(BaseSDTrainProcess):
                     pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals

-        prior_pred = None
-        if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
-            with self.timer('prior predict'):
-                prior_pred = self.get_prior_prediction(
-                    noisy_latents=noisy_latents,
-                    conditional_embeds=conditional_embeds,
-                    match_adapter_assist=match_adapter_assist,
-                    network_weight_list=network_weight_list,
-                    timesteps=timesteps,
-                    pred_kwargs=pred_kwargs,
-                    noise=noise,
-                    batch=batch,
-                )

         if self.adapter and isinstance(self.adapter, IPAdapter):
             with self.timer('encode_adapter_embeds'):
@@ -788,6 +807,20 @@ class SDTrainer(BaseSDTrainProcess):
             with self.timer('encode_adapter'):
                 conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach())

+        prior_pred = None
+        if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg):
+            with self.timer('prior predict'):
+                prior_pred = self.get_prior_prediction(
+                    noisy_latents=noisy_latents,
+                    conditional_embeds=conditional_embeds,
+                    match_adapter_assist=match_adapter_assist,
+                    network_weight_list=network_weight_list,
+                    timesteps=timesteps,
+                    pred_kwargs=pred_kwargs,
+                    noise=noise,
+                    batch=batch,
+                )
+
         self.before_unet_predict()
         # do a prior pred if we have an unconditional image, we will swap out the giadance later
         if batch.unconditional_latents is not None or self.do_guided_loss:
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index f3aa6ff7..dd8dfa19 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -480,7 +480,12 @@ class ControlFileItemDTOMixin:
                 print(f"Error: {e}")
                 print(f"Error loading image: {self.control_path}")

-        if not self.full_size_control_images:
+        if self.full_size_control_images:
+            # we just scale them to 512x512:
+            w, h = img.size
+            img = img.resize((512, 512), Image.BICUBIC)
+
+        else:
             w, h = img.size
             if w > h and self.scale_to_width < self.scale_to_height:
                 # throw error, they should match
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index d7d9fa5a..68fa3bb4 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -12,7 +12,7 @@ from toolkit.train_tools import get_torch_dtype
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
 from collections import OrderedDict
-from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
+from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, AttnProcessor2_0
 from ipadapter.ip_adapter.ip_adapter import ImageProjModel
 from ipadapter.ip_adapter.resampler import Resampler
 from toolkit.config_modules import AdapterConfig
@@ -27,10 +27,120 @@
 from transformers import (
     CLIPVisionModelWithProjection,
 )
-from diffusers.models.attention_processor import (
-    IPAdapterAttnProcessor,
-    IPAdapterAttnProcessor2_0,
-)
+import torch.nn.functional as F
+
+
+class CustomIPAttentionProcessor(IPAttnProcessor2_0):
+    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
+        super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+
+    def __call__(
+            self,
+            attn,
+            hidden_states,
+            encoder_hidden_states=None,
+            attention_mask=None,
+            temb=None,
+    ):
+        is_active = self.adapter_ref().is_active
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        # will be none if disabled
+        if not is_active:
+            ip_hidden_states = None
+            if encoder_hidden_states is None:
+                encoder_hidden_states = hidden_states
+            elif attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+        else:
+            # get encoder_hidden_states, ip_hidden_states
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                encoder_hidden_states[:, end_pos:, :],
+            )
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # will be none if disabled
+        if ip_hidden_states is not None:
+            # for ip-adapter
+            ip_key = self.to_k_ip(ip_hidden_states)
+            ip_value = self.to_v_ip(ip_hidden_states)
+
+            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            ip_hidden_states = F.scaled_dot_product_attention(
+                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+
+            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+            hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states

 # loosely based on
 # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
@@ -44,6 +154,8 @@ class IPAdapter(torch.nn.Module):
         self.clip_image_processor = CLIPImageProcessor()
         self.device = self.sd_ref().unet.device
         self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)
+        self.current_scale = 1.0
+        self.is_active = True
         if adapter_config.type == 'ip':
             # ip-adapter
             image_proj_model = ImageProjModel(
@@ -84,14 +196,29 @@ class IPAdapter(torch.nn.Module):
                 # they didnt have this, but would lead to undefined below
                 raise ValueError(f"unknown attn processor name: {name}")
             if cross_attention_dim is None:
-                attn_procs[name] = AttnProcessor()
+                attn_procs[name] = AttnProcessor2_0()
             else:
                 layer_name = name.split(".processor")[0]
                 weights = {
                     "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                     "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
                 }
-                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+                if adapter_config.type == 'ip':
+                    # ip-adapter
+                    num_tokens = 4
+                elif adapter_config.type == 'ip+':
+                    # ip-adapter-plus
+                    num_tokens = 16
+                else:
+                    raise ValueError(f"unknown adapter type: {adapter_config.type}")
+
+                attn_procs[name] = CustomIPAttentionProcessor(
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    scale=1.0,
+                    num_tokens=num_tokens,
+                    adapter=self
+                )
                 attn_procs[name].load_state_dict(weights)
         sd.unet.set_attn_processor(attn_procs)
         adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
@@ -132,13 +259,18 @@ class IPAdapter(torch.nn.Module):
         state_dict["ip_adapter"] = self.adapter_modules.state_dict()
         return state_dict

+    def get_scale(self):
+        return self.current_scale
+
     def set_scale(self, scale):
+        self.current_scale = scale
         for attn_processor in self.sd_ref().unet.attn_processors.values():
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale

     @torch.no_grad()
-    def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], drop=False) -> torch.Tensor:
+    def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
+                                       drop=False) -> torch.Tensor:
         # todo: add support for sdxl
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
@@ -191,4 +323,3 @@ class IPAdapter(torch.nn.Module):
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
         self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
-
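
A minimal sketch (not part of the patch) of how the stretched inverted-mask prior weighting in the first SDTrainer hunk behaves. The value_map below is a hypothetical stand-in for the toolkit helper the patch calls, assumed to be a plain linear remap, and mask_min_value plus the example tensor are made-up values:

import torch

def value_map(x, in_min, in_max, out_min, out_max):
    # linear remap of x from [in_min, in_max] to [out_min, out_max]
    return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min

mask_min_value = 0.1                               # assumed dataset_config.mask_min_value
mask_multiplier = torch.tensor([0.1, 0.55, 1.0])   # example per-pixel mask weights

stretched = value_map(mask_multiplier, mask_min_value, 1.0, 0.0, 1.0)
prior_mask_multiplier = 1.0 - stretched
print(prior_mask_multiplier)  # tensor([1.0000, 0.5000, 0.0000])

The stretch lets the prior term reach full strength exactly where the training mask sits at its floor value, instead of topping out at 1.0 - mask_min_value as the old 1.0 - mask_multiplier did.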
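The get_prior_prediction changes follow a save / disable / predict / restore pattern: the LoRA network and the IP-Adapter are switched off via their is_active flags, the UNet is put into eval mode, and the previous state is put back afterwards. A rough sketch of that pattern, assuming hypothetical unet/network/adapter objects that expose the same flags as the patch:

import torch

def predict_prior(unet, network, adapter, *args, **kwargs):
    was_training = unet.training
    was_network_active = network.is_active if network is not None else False
    was_adapter_active = adapter.is_active if adapter is not None else False
    try:
        # disable anything that would alter the base model's prediction
        if network is not None:
            network.is_active = False
        if adapter is not None:
            adapter.is_active = False
        unet.eval()
        with torch.no_grad():
            return unet(*args, **kwargs)
    finally:
        # restore the previous activation and training state
        if network is not None:
            network.is_active = was_network_active
        if adapter is not None:
            adapter.is_active = was_adapter_active
        if was_training:
            unet.train()

The attention processor added in ip_adapter.py honors the same flag: when is_active is False, __call__ leaves ip_hidden_states as None and skips the to_k_ip / to_v_ip branch, so the adapter's scaled contribution is left out of the attention output.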