diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 004b5e7f..c5ede3ca 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -15,6 +15,7 @@ from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss from toolkit.image_utils import show_tensors, show_latents from toolkit.ip_adapter import IPAdapter from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds +from toolkit.reference_adapter import ReferenceAdapter from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \ apply_learnable_snr_gos, LearnableSNRGamma @@ -285,9 +286,11 @@ class SDTrainer(BaseSDTrainProcess): if torch.isnan(prior_loss).any(): raise ValueError("Prior loss is nan") - # prior_loss = prior_loss.mean([1, 2, 3]) - loss = loss + prior_loss + prior_loss = prior_loss.mean([1, 2, 3]) + # loss = loss + prior_loss loss = loss.mean([1, 2, 3]) + if prior_loss is not None: + loss = loss + prior_loss if not self.train_config.train_turbo: if self.train_config.learnable_snr_gos: @@ -623,11 +626,11 @@ class SDTrainer(BaseSDTrainProcess): if self.network is not None: was_network_active = self.network.is_active self.network.is_active = False - is_ip_adapter = False - was_ip_adapter_active = False - if self.adapter is not None and isinstance(self.adapter, IPAdapter): - is_ip_adapter = True - was_ip_adapter_active = self.adapter.is_active + can_disable_adapter = False + was_adapter_active = False + if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ReferenceAdapter)): + can_disable_adapter = True + was_adapter_active = self.adapter.is_active self.adapter.is_active = False # do a prediction here so we can match its output with network multiplier set to 0.0 @@ -666,8 +669,8 @@ class SDTrainer(BaseSDTrainProcess): if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs: del pred_kwargs['down_intrablock_additional_residuals'] - if is_ip_adapter: - self.adapter.is_active = was_ip_adapter_active + if can_disable_adapter: + self.adapter.is_active = was_adapter_active # restore network # self.network.multiplier = network_weight_list if self.network is not None: @@ -950,12 +953,7 @@ class SDTrainer(BaseSDTrainProcess): if self.adapter and isinstance(self.adapter, IPAdapter): with self.timer('encode_adapter_embeds'): - if has_clip_image: - conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( - clip_images.detach().to(self.device_torch, dtype=dtype), - is_training=True - ) - elif is_reg: + if is_reg: # we will zero it out in the img embedder clip_images = torch.zeros( (noisy_latents.shape[0], 3, 512, 512), @@ -967,6 +965,11 @@ class SDTrainer(BaseSDTrainProcess): drop=True, is_training=True ) + elif has_clip_image: + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + clip_images.detach().to(self.device_torch, dtype=dtype), + is_training=True + ) else: raise ValueError("Adapter images now must be loaded with dataloader or be a reg image") @@ -978,12 +981,26 @@ class SDTrainer(BaseSDTrainProcess): with self.timer('encode_adapter'): conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds) + if self.adapter and isinstance(self.adapter, ReferenceAdapter): + # pass in our scheduler + self.adapter.noise_scheduler = self.lr_scheduler + if has_clip_image or 
has_adapter_img: + img_to_use = clip_images if has_clip_image else adapter_images + # currently 0-1 needs to be -1 to 1 + reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype) + self.adapter.set_reference_images(reference_images) + self.adapter.noise_scheduler = self.sd.noise_scheduler + elif is_reg: + self.adapter.set_blank_reference_images(noisy_latents.shape[0]) + else: + self.adapter.set_reference_images(None) + prior_pred = None do_reg_prior = False - if is_reg and (self.network is not None or self.adapter is not None): - # we are doing a reg image and we have a network or adapter - do_reg_prior = True + # if is_reg and (self.network is not None or self.adapter is not None): + # # we are doing a reg image and we have a network or adapter + # do_reg_prior = True do_inverted_masked_prior = False if self.train_config.inverted_mask_prior and batch.mask_tensor is not None: diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index bb88937e..d7c3d939 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -31,6 +31,7 @@ from toolkit.network_mixins import Network from toolkit.optimizer import get_optimizer from toolkit.paths import CONFIG_ROOT from toolkit.progress_bar import ToolkitProgressBar +from toolkit.reference_adapter import ReferenceAdapter from toolkit.sampler import get_sampler from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \ load_ip_adapter_model @@ -140,7 +141,7 @@ class BaseSDTrainProcess(BaseTrainProcess): # to hold network if there is one self.network: Union[Network, None] = None - self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, None] = None + self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, None] = None self.embedding: Union[Embedding, None] = None is_training_adapter = self.adapter_config is not None and self.adapter_config.train @@ -771,8 +772,12 @@ class BaseSDTrainProcess(BaseTrainProcess): num_train_timesteps, device=self.device_torch ) + content_or_style = self.train_config.content_or_style + if is_reg: + content_or_style = self.train_config.content_or_style_reg + # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content': - if self.train_config.content_or_style in ['style', 'content']: + if content_or_style in ['style', 'content']: # this is from diffusers training code # Cubic sampling for favoring later or earlier timesteps # For more details about why cubic sampling is used for content / structure, @@ -783,9 +788,9 @@ class BaseSDTrainProcess(BaseTrainProcess): orig_timesteps = torch.rand((batch_size,), device=latents.device) - if self.train_config.content_or_style == 'content': + if content_or_style == 'content': timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps'] - elif self.train_config.content_or_style == 'style': + elif content_or_style == 'style': timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps'] timestep_indices = value_map( @@ -800,7 +805,7 @@ class BaseSDTrainProcess(BaseTrainProcess): max_noise_steps - 1 ) - elif self.train_config.content_or_style == 'balanced': + elif content_or_style == 'balanced': if min_noise_steps == max_noise_steps: timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps else: @@ -813,7 +818,7 @@ class BaseSDTrainProcess(BaseTrainProcess): ) timestep_indices = 
timestep_indices.long() else: - raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}") + raise ValueError(f"Unknown content_or_style {content_or_style}") # convert the timestep_indices to a timestep timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices] @@ -824,9 +829,16 @@ class BaseSDTrainProcess(BaseTrainProcess): height=latents.shape[2], width=latents.shape[3], batch_size=batch_size, - noise_offset=self.train_config.noise_offset + noise_offset=self.train_config.noise_offset, ).to(self.device_torch, dtype=dtype) + # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents + # this will negate any noise offsets + if self.train_config.dynamic_noise_offset and not is_reg: + latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2 + # subtract channel mean to that we compensate for the mean of the latents on the noise offset per channel + noise = noise + latents_channel_mean + if self.train_config.loss_target == 'differential_noise': differential = latents - unaugmented_latents # add noise to differential @@ -912,6 +924,8 @@ class BaseSDTrainProcess(BaseTrainProcess): suffix = 't2i' elif self.adapter_config.type == 'clip': suffix = 'clip' + elif self.adapter_config.type == 'reference': + suffix = 'ref' else: suffix = 'ip' adapter_name = self.name @@ -943,6 +957,11 @@ class BaseSDTrainProcess(BaseTrainProcess): sd=self.sd, adapter_config=self.adapter_config, ) + elif self.adapter_config.type == 'reference': + self.adapter = ReferenceAdapter( + sd=self.sd, + adapter_config=self.adapter_config, + ) else: self.adapter = IPAdapter( sd=self.sd, @@ -1441,6 +1460,8 @@ class BaseSDTrainProcess(BaseTrainProcess): did_first_flush = True # flush() # setup the networks to gradient checkpointing and everything works + if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() with torch.no_grad(): # torch.cuda.empty_cache() diff --git a/toolkit/basic.py b/toolkit/basic.py index 6a70bf61..0d32a9d2 100644 --- a/toolkit/basic.py +++ b/toolkit/basic.py @@ -31,12 +31,18 @@ def get_mean_std(tensor): def adain(content_features, style_features): # Assumes that the content and style features are of shape (batch_size, channels, width, height) + dims = [2, 3] + if len(content_features.shape) == 3: + # content_features = content_features.unsqueeze(0) + # style_features = style_features.unsqueeze(0) + dims = [1] + # Step 1: Calculate mean and variance of content features - content_mean, content_var = torch.mean(content_features, dim=[2, 3], keepdim=True), torch.var(content_features, - dim=[2, 3], + content_mean, content_var = torch.mean(content_features, dim=dims, keepdim=True), torch.var(content_features, + dim=dims, keepdim=True) # Step 2: Calculate mean and variance of style features - style_mean, style_var = torch.mean(style_features, dim=[2, 3], keepdim=True), torch.var(style_features, dim=[2, 3], + style_mean, style_var = torch.mean(style_features, dim=dims, keepdim=True), torch.var(style_features, dim=dims, keepdim=True) # Step 3: Normalize content features diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 05277d52..387d2a9d 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -178,6 +178,7 @@ class TrainConfig: def __init__(self, **kwargs): self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm') self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced') + 
self.content_or_style_reg: ContentOrStyleType = kwargs.get('content_or_style_reg', 'balanced')
         self.steps: int = kwargs.get('steps', 1000)
         self.lr = kwargs.get('lr', 1e-6)
         self.unet_lr = kwargs.get('unet_lr', self.lr)
@@ -268,6 +269,8 @@ class TrainConfig:
         if self.train_turbo and not self.noise_scheduler.startswith("euler"):
             raise ValueError(f"train_turbo is only supported with euler and wuler_a noise schedulers")

+        self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
+

 class ModelConfig:
     def __init__(self, **kwargs):
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index efa9ac07..f01977c3 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -232,6 +232,10 @@ class IPAdapter(torch.nn.Module):
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
+            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]
+
+            if self.config.image_encoder_arch == 'safe':
+                embedding_dim = self.config.safe_channels

             # size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
             # size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
             # ip-adapter-plus
@@ -241,7 +245,7 @@
                 dim_head=64,
                 heads=heads,
                 num_queries=self.config.num_tokens,  # usually 16
-                embedding_dim=self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1],
+                embedding_dim=embedding_dim,
                 output_dim=sd.unet.config['cross_attention_dim'],
                 ff_mult=4
             )
diff --git a/toolkit/reference_adapter.py b/toolkit/reference_adapter.py
new file mode 100644
index 00000000..90995ffc
--- /dev/null
+++ b/toolkit/reference_adapter.py
@@ -0,0 +1,411 @@
+import math
+
+import torch
+import sys
+
+from PIL import Image
+from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+from torch.nn import Parameter
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+from toolkit.basic import adain
+from toolkit.paths import REPOS_ROOT
+from toolkit.saving import load_ip_adapter_model
+from toolkit.train_tools import get_torch_dtype
+
+sys.path.append(REPOS_ROOT)
+from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
+from collections import OrderedDict
+from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
+    AttnProcessor2_0
+from ipadapter.ip_adapter.ip_adapter import ImageProjModel
+from ipadapter.ip_adapter.resampler import Resampler
+from toolkit.config_modules import AdapterConfig
+from toolkit.prompt_utils import PromptEmbeds
+import weakref
+
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+
+from diffusers import (
+    EulerDiscreteScheduler,
+    DDPMScheduler,
+)
+
+from transformers import (
+    CLIPImageProcessor,
+    CLIPVisionModelWithProjection
+)
+from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
+
+from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
+
+from transformers import ViTFeatureExtractor, ViTForImageClassification
+
+import torch.nn.functional as F
+import torch.nn as nn
+
+
+class 
ReferenceAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.ref_net = nn.Linear(hidden_size, hidden_size) + self.blend = nn.Parameter(torch.zeros(hidden_size)) + self.adapter_ref: weakref.ref = weakref.ref(adapter) + self._memory = None + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + if self.adapter_ref().is_active: + if 
self.adapter_ref().reference_mode == "write": + # write_mode + memory_ref = self.ref_net(hidden_states) + self._memory = memory_ref + elif self.adapter_ref().reference_mode == "read": + # read_mode + if self._memory is None: + print("Warning: no memory to read from") + else: + + saved_hidden_states = self._memory + try: + new_hidden_states = saved_hidden_states + blend = self.blend + # expand the blend buyt keep dim 0 the same (batch) + while blend.ndim < new_hidden_states.ndim: + blend = blend.unsqueeze(0) + # expand batch + blend = torch.cat([blend] * new_hidden_states.shape[0], dim=0) + hidden_states = blend * new_hidden_states + (1 - blend) * hidden_states + except Exception as e: + raise Exception(f"Error blending: {e}") + + return hidden_states + + +class ReferenceAdapter(torch.nn.Module): + + def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'): + super().__init__() + self.config = adapter_config + self.sd_ref: weakref.ref = weakref.ref(sd) + self.device = self.sd_ref().unet.device + self.reference_mode = "read" + self.current_scale = 1.0 + self.is_active = True + self._reference_images = None + self._reference_latents = None + self.has_memory = False + + self.noise_scheduler: Union[DDPMScheduler, EulerDiscreteScheduler] = None + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + for name in sd.unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor2_0() + else: + # layer_name = name.split(".processor")[0] + # weights = { + # "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], + # "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], + # } + + attn_procs[name] = ReferenceAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + scale=1.0, + num_tokens=self.config.num_tokens, + adapter=self + ) + # attn_procs[name].load_state_dict(weights) + sd.unet.set_attn_processor(attn_procs) + adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + sd.adapter = self + self.unet_ref: weakref.ref = weakref.ref(sd.unet) + self.adapter_modules = adapter_modules + # load the weights if we have some + if self.config.name_or_path: + loaded_state_dict = load_ip_adapter_model( + self.config.name_or_path, + device='cpu', + dtype=sd.torch_dtype + ) + self.load_state_dict(loaded_state_dict) + + self.set_scale(1.0) + self.attach() + self.to(self.device, self.sd_ref().torch_dtype) + + # if self.config.train_image_encoder: + # self.image_encoder.train() + # self.image_encoder.requires_grad_(True) + + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + # self.image_encoder.to(*args, **kwargs) + # self.image_proj_model.to(*args, **kwargs) + self.adapter_modules.to(*args, **kwargs) + return self + + def load_reference_adapter(self, state_dict: Union[OrderedDict, dict]): + reference_layers = 
torch.nn.ModuleList(self.sd_ref().unet.attn_processors.values())
+        reference_layers.load_state_dict(state_dict["reference_adapter"])
+
+    # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
+    #     self.load_ip_adapter(state_dict)
+
+    def state_dict(self) -> OrderedDict:
+        state_dict = OrderedDict()
+        state_dict["reference_adapter"] = self.adapter_modules.state_dict()
+        return state_dict
+
+    def get_scale(self):
+        return self.current_scale
+
+    def set_reference_images(self, reference_images: Optional[torch.Tensor]):
+        # reference_images may be None (the trainer passes None to clear the reference for a batch)
+        self._reference_images = None if reference_images is None else reference_images.clone().detach()
+        self._reference_latents = None
+        self.clear_memory()
+
+    def set_blank_reference_images(self, batch_size):
+        self._reference_images = torch.zeros((batch_size, 3, 512, 512), device=self.device, dtype=self.sd_ref().torch_dtype)
+        self._reference_latents = torch.zeros((batch_size, 4, 64, 64), device=self.device, dtype=self.sd_ref().torch_dtype)
+        self.clear_memory()
+
+
+    def set_scale(self, scale):
+        self.current_scale = scale
+        for attn_processor in self.sd_ref().unet.attn_processors.values():
+            if isinstance(attn_processor, ReferenceAttnProcessor2_0):
+                attn_processor.scale = scale
+
+
+    def attach(self):
+        unet = self.sd_ref().unet
+        self._original_unet_forward = unet.forward
+        unet.forward = lambda *args, **kwargs: self.unet_forward(*args, **kwargs)
+        if self.sd_ref().network is not None:
+            # set network to not merge in
+            self.sd_ref().network.can_merge_in = False
+
+    def unet_forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
+        skip = False
+        if self._reference_images is None and self._reference_latents is None:
+            skip = True
+        if not self.is_active:
+            skip = True
+
+        if self.has_memory:
+            skip = True
+
+        if not skip:
+            if self.sd_ref().network is not None:
+                self.sd_ref().network.is_active = True
+                if self.sd_ref().network.is_merged_in:
+                    raise ValueError("network is merged in, but we are not supposed to be merged in")
+            # send it through our forward first
+            self.forward(sample, timestep, encoder_hidden_states, *args, **kwargs)
+
+        if self.sd_ref().network is not None:
+            self.sd_ref().network.is_active = False
+
+        # Send it through the original unet forward
+        return self._original_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs)
+
+
+    # use drop for prompt dropout, or negatives
+    def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
+        if not self.noise_scheduler:
+            raise ValueError("noise scheduler not set")
+        if not self.is_active or (self._reference_images is None and self._reference_latents is None):
+            raise ValueError("reference adapter not active or no reference images set")
+        # todo may need to handle cfg?
+        self.reference_mode = "write"
+
+        if self._reference_latents is None:
+            self._reference_latents = self.sd_ref().encode_images(self._reference_images.to(
+                self.device, self.sd_ref().torch_dtype
+            )).detach()
+        # create a sample from our reference images
+        reference_latents = self._reference_latents.clone().detach().to(self.device, self.sd_ref().torch_dtype)
+        # if our number of reference samples is half the incoming batch, we are doing cfg. 
Zero out the first half (unconditional) + if reference_latents.shape[0] * 2 == sample.shape[0]: + # we are doing cfg + # Unconditional goes first + reference_latents = torch.cat([torch.zeros_like(reference_latents), reference_latents], dim=0).detach() + + # resize it so reference_latents will fit inside sample in the center + width_scale = sample.shape[2] / reference_latents.shape[2] + height_scale = sample.shape[3] / reference_latents.shape[3] + scale = min(width_scale, height_scale) + # resize the reference latents + + mode = "bilinear" if scale > 1.0 else "bicubic" + + reference_latents = F.interpolate( + reference_latents, + size=(int(reference_latents.shape[2] * scale), int(reference_latents.shape[3] * scale)), + mode=mode, + align_corners=False + ) + + # add 0 padding if needed + width_pad = (sample.shape[2] - reference_latents.shape[2]) / 2 + height_pad = (sample.shape[3] - reference_latents.shape[3]) / 2 + reference_latents = F.pad( + reference_latents, + (math.floor(width_pad), math.floor(width_pad), math.ceil(height_pad), math.ceil(height_pad)), + mode="constant", + value=0 + ) + + # resize again just to make sure it is exact same size + reference_latents = F.interpolate( + reference_latents, + size=(sample.shape[2], sample.shape[3]), + mode="bicubic", + align_corners=False + ) + + # todo maybe add same noise to the sample? For now we will send it through with no noise + # sample_imgs = self.noise_scheduler.add_noise(sample_imgs, timestep) + self._original_unet_forward(reference_latents, timestep, encoder_hidden_states, *args, **kwargs) + self.reference_mode = "read" + self.has_memory = True + return None + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + for attn_processor in self.adapter_modules: + yield from attn_processor.parameters(recurse) + # yield from self.image_proj_model.parameters(recurse) + # if self.config.train_image_encoder: + # yield from self.image_encoder.parameters(recurse) + # if self.config.train_image_encoder: + # yield from self.image_encoder.parameters(recurse) + # self.image_encoder.train() + # else: + # for attn_processor in self.adapter_modules: + # yield from attn_processor.parameters(recurse) + # yield from self.image_proj_model.parameters(recurse) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + strict = False + # self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict) + self.adapter_modules.load_state_dict(state_dict["reference_adapter"], strict=strict) + + def enable_gradient_checkpointing(self): + self.image_encoder.gradient_checkpointing = True + + def clear_memory(self): + for attn_processor in self.adapter_modules: + if isinstance(attn_processor, ReferenceAttnProcessor2_0): + attn_processor._memory = None + self.has_memory = False diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 72b442d6..f1170355 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -27,6 +27,7 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig from toolkit.metadata import get_meta_for_safetensors from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds +from toolkit.reference_adapter import ReferenceAdapter from toolkit.sampler import get_sampler from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers from toolkit.sd_device_states_presets import empty_preset @@ -76,6 +77,7 @@ class BlankNetwork: 
self.multiplier = 1.0 self.is_active = True self.is_merged_in = False + self.can_merge_in = False def __enter__(self): self.is_active = True @@ -134,7 +136,7 @@ class StableDiffusion: # to hold network if there is one self.network = None - self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None + self.adapter: Union['T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None self.is_xl = model_config.is_xl self.is_v2 = model_config.is_v2 self.is_ssd = model_config.is_ssd @@ -396,6 +398,9 @@ class StableDiffusion: else: Pipe = StableDiffusionAdapterPipeline extra_args['adapter'] = self.adapter + elif isinstance(self.adapter, ReferenceAdapter): + # pass the noise scheduler to the adapter + self.adapter.noise_scheduler = noise_scheduler else: if self.is_xl: extra_args['add_watermarker'] = False @@ -478,6 +483,12 @@ class StableDiffusion: transforms.ToTensor(), ]) validation_image = transform(validation_image) + if isinstance(self.adapter, ReferenceAdapter): + # need -1 to 1 + validation_image = transforms.ToTensor()(validation_image) + validation_image = validation_image * 2.0 - 1.0 + validation_image = validation_image.unsqueeze(0) + self.adapter.set_reference_images(validation_image) if self.network is not None: self.network.multiplier = gen_config.network_multiplier @@ -594,6 +605,9 @@ class StableDiffusion: gen_config.save_image(img, i) + if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() + # clear pipeline and cache to reduce vram usage del pipeline if refiner_pipeline is not None: @@ -1455,6 +1469,10 @@ class StableDiffusion: elif isinstance(self.adapter, ClipVisionAdapter): requires_grad = self.adapter.embedder.training adapter_device = self.adapter.device + elif isinstance(self.adapter, ReferenceAdapter): + # todo update this!! + requires_grad = True + adapter_device = self.adapter.device else: raise ValueError(f"Unknown adapter type: {type(self.adapter)}") self.device_state['adapter'] = {
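
Below is a minimal, illustrative sketch of how the options introduced in this diff fit together. The keys content_or_style_reg, dynamic_noise_offset, noise_offset, the adapter type 'reference', and the fields name_or_path, num_tokens, and train all come from the diff above; the keyword-style constructors and the exact AdapterConfig signature are assumptions based on how toolkit/config_modules.py is used here, not a verified API.

# Illustrative sketch only (assumptions noted above): exercising the new config options from this diff.
from toolkit.config_modules import TrainConfig, AdapterConfig

train_config = TrainConfig(
    content_or_style='balanced',    # timestep sampling for regular batches (existing option)
    content_or_style_reg='style',   # new: separate timestep sampling used when is_reg is True
    noise_offset=0.0,
    dynamic_noise_offset=True,      # new: shift noise toward the per-channel latent mean (skipped for reg batches)
)

adapter_config = AdapterConfig(
    type='reference',               # new: BaseSDTrainProcess builds a ReferenceAdapter and saves with the 'ref' suffix
    name_or_path=None,              # optional pretrained weights, loaded via load_ip_adapter_model
    num_tokens=4,                   # passed through to ReferenceAttnProcessor2_0; value here is illustrative
    train=True,
)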