Files
ai-toolkit/toolkit/reference_adapter.py

412 lines
16 KiB
Python

import math
import torch
import sys
from PIL import Image
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.basic import adain
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from diffusers import (
EulerDiscreteScheduler,
DDPMScheduler,
)
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection
)
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel
from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
from transformers import ViTFeatureExtractor, ViTForImageClassification
import torch.nn.functional as F
import torch.nn as nn
class ReferenceAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor that records and replays reference features (PyTorch 2.0 SDPA).

    The processor runs regular scaled-dot-product attention. Afterwards, if the
    owning adapter is active, it does one of two things depending on the
    adapter's ``reference_mode``:

    * ``"write"``: the attention output is projected through ``ref_net`` and
      stored in ``_memory``; the attention output itself is returned unchanged.
    * ``"read"``: the stored projection is mixed into the attention output as
      ``blend * memory + (1 - blend) * output``. ``blend`` is initialised to
      zeros, so a freshly created processor returns the plain attention output.

    Args:
        hidden_size (`int`):
            Feature size of the ``ref_net`` projection and of the ``blend`` vector.
        cross_attention_dim (`int`, optional):
            Stored on the instance; not used by ``__call__``.
        scale (`float`, defaults to 1.0):
            Stored on the instance; not applied by ``__call__``.
        num_tokens (`int`, defaults to 4):
            Stored on the instance; not used by ``__call__``.
        adapter (optional):
            Object exposing ``is_active`` and ``reference_mode``. Only a weak
            reference is kept. When it is ``None`` (or has been garbage
            collected) the processor behaves like plain attention.

    Raises:
        ImportError: if ``torch.nn.functional.scaled_dot_product_attention`` is missing.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # trainable projection applied to the attention output in write mode
        self.ref_net = nn.Linear(hidden_size, hidden_size)
        # trainable mixing weights; zeros means "ignore the reference"
        self.blend = nn.Parameter(torch.zeros(hidden_size))

        # weakref.ref(None) raises TypeError, so only wrap a real adapter
        self.adapter_ref = weakref.ref(adapter) if adapter is not None else None
        self._memory = None

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        # Resolve the weak reference once; a missing or collected adapter
        # simply means "no reference behaviour".
        adapter = self.adapter_ref() if self.adapter_ref is not None else None
        if adapter is None or not adapter.is_active:
            return hidden_states

        if adapter.reference_mode == "write":
            # remember the projected features; the output itself is untouched
            self._memory = self.ref_net(hidden_states)
        elif adapter.reference_mode == "read":
            if self._memory is None:
                print("Warning: no memory to read from")
            else:
                hidden_states = self._blend_with_memory(hidden_states)

        return hidden_states

    def _blend_with_memory(self, hidden_states):
        """Mix the stored reference features into ``hidden_states`` using ``blend``."""
        memory = self._memory
        try:
            blend = self.blend
            # Prepend singleton dims so blend broadcasts over every leading axis
            # (batch included); this matches repeating it per batch entry.
            # NOTE(review): blend lines up with the LAST axis, which is the
            # feature axis for 3-D inputs - confirm 4-D inputs are not expected.
            while blend.ndim < memory.ndim:
                blend = blend.unsqueeze(0)
            return blend * memory + (1 - blend) * hidden_states
        except Exception as e:
            # keep the original exception type/message, but preserve the cause
            raise Exception(f"Error blending: {e}") from e
class ReferenceAdapter(torch.nn.Module):
    """
    Adapter that feeds reference latents through the UNet before the real pass.

    On construction every cross-attention processor of ``sd.unet`` is replaced
    by a ``ReferenceAttnProcessor2_0`` (self-attention gets ``AttnProcessor2_0``)
    and ``unet.forward`` is wrapped by ``unet_forward``. When reference images
    or latents are set and no memory has been written yet, the wrapper first
    runs the reference latents through the original UNet forward with
    ``reference_mode == "write"``, then runs the real sample with
    ``reference_mode == "read"``.

    Args:
        sd: the StableDiffusion wrapper; only a weak reference is kept.
        adapter_config: adapter configuration; ``num_tokens`` and
            ``name_or_path`` are read here.
    """

    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
        super().__init__()
        self.config = adapter_config
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.device = self.sd_ref().unet.device
        self.reference_mode = "read"
        self.current_scale = 1.0
        self.is_active = True
        self._reference_images = None
        self._reference_latents = None
        self.has_memory = False
        # must be assigned by the caller before forward() is used
        self.noise_scheduler: Union[DDPMScheduler, EulerDiscreteScheduler] = None

        # init adapter modules
        attn_procs = {}
        for name in sd.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
                hidden_size = sd.unet.config['block_out_channels'][-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = sd.unet.config['block_out_channels'][block_id]
            else:
                # without this branch hidden_size would be undefined below
                raise ValueError(f"unknown attn processor name: {name}")

            if cross_attention_dim is None:
                # self-attention keeps the stock processor
                attn_procs[name] = AttnProcessor2_0()
            else:
                attn_procs[name] = ReferenceAttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.config.num_tokens,
                    adapter=self
                )

        sd.unet.set_attn_processor(attn_procs)
        adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

        sd.adapter = self
        self.unet_ref: weakref.ref = weakref.ref(sd.unet)
        self.adapter_modules = adapter_modules

        # load the weights if we have some
        if self.config.name_or_path:
            loaded_state_dict = load_ip_adapter_model(
                self.config.name_or_path,
                device='cpu',
                dtype=sd.torch_dtype
            )
            self.load_state_dict(loaded_state_dict)

        self.set_scale(1.0)
        self.attach()
        self.to(self.device, self.sd_ref().torch_dtype)

    def to(self, *args, **kwargs):
        """Move the adapter and its attention processors; returns ``self``."""
        super().to(*args, **kwargs)
        self.adapter_modules.to(*args, **kwargs)
        return self

    def load_reference_adapter(self, state_dict: Union[OrderedDict, dict]):
        """Load ``state_dict["reference_adapter"]`` into the UNet's current attention processors."""
        # the adapter has no `pipe` attribute; the UNet is reached through sd_ref
        reference_layers = torch.nn.ModuleList(self.sd_ref().unet.attn_processors.values())
        reference_layers.load_state_dict(state_dict["reference_adapter"])

    def state_dict(self) -> OrderedDict:
        """Return ``{"reference_adapter": <state dict of the attention processors>}``."""
        state_dict = OrderedDict()
        state_dict["reference_adapter"] = self.adapter_modules.state_dict()
        return state_dict

    def get_scale(self):
        """Return the scale last passed to ``set_scale``."""
        return self.current_scale

    def set_reference_images(self, reference_images: Optional[torch.Tensor]):
        """
        Store a detached copy of ``reference_images`` and drop cached latents/memory.

        Passing ``None`` clears the reference, so the UNet wrapper skips the
        reference pass.
        """
        if reference_images is None:
            self._reference_images = None
        else:
            self._reference_images = reference_images.clone().detach()
        self._reference_latents = None
        self.clear_memory()

    def set_blank_reference_images(self, batch_size):
        """Use all-zero 512x512 images and matching all-zero 4x64x64 latents as the reference."""
        self._reference_images = torch.zeros((batch_size, 3, 512, 512), device=self.device, dtype=self.sd_ref().torch_dtype)
        self._reference_latents = torch.zeros((batch_size, 4, 64, 64), device=self.device, dtype=self.sd_ref().torch_dtype)
        self.clear_memory()

    def set_scale(self, scale):
        """Record ``scale`` and copy it onto every reference attention processor."""
        self.current_scale = scale
        for attn_processor in self.sd_ref().unet.attn_processors.values():
            if isinstance(attn_processor, ReferenceAttnProcessor2_0):
                attn_processor.scale = scale

    def attach(self):
        """Wrap ``unet.forward`` with ``unet_forward`` and forbid network merge-in."""
        unet = self.sd_ref().unet
        self._original_unet_forward = unet.forward
        unet.forward = lambda *args, **kwargs: self.unet_forward(*args, **kwargs)
        if self.sd_ref().network is not None:
            # set network to not merge in
            self.sd_ref().network.can_merge_in = False

    def unet_forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
        """
        Replacement for ``unet.forward``.

        Runs the reference pass first (unless there is no reference, the
        adapter is inactive, or memory was already written), then returns the
        result of the original UNet forward on ``sample``.

        Raises:
            ValueError: if a network is attached and reports ``is_merged_in``.
        """
        skip = False
        if self._reference_images is None and self._reference_latents is None:
            skip = True
        if not self.is_active:
            skip = True
        if self.has_memory:
            skip = True

        if not skip:
            if self.sd_ref().network is not None:
                self.sd_ref().network.is_active = True
                if self.sd_ref().network.is_merged_in:
                    raise ValueError("network is merged in, but we are not supposed to be merged in")
            # send it through our forward first
            self.forward(sample, timestep, encoder_hidden_states, *args, **kwargs)

            if self.sd_ref().network is not None:
                self.sd_ref().network.is_active = False

        # Send it through the original unet forward. Extra positionals must be
        # unpacked, otherwise the UNet receives one tuple as its 4th argument.
        return self._original_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs)

    # use drop for prompt dropout, or negatives
    def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
        """
        Run the reference latents through the original UNet in write mode.

        Always returns ``None``; the useful effect is the memory written by the
        attention processors. ``has_memory`` is set only when the pass succeeds.

        Raises:
            ValueError: if no noise scheduler is set, the adapter is inactive,
                or no reference images/latents are set.
        """
        if not self.noise_scheduler:
            raise ValueError("noise scheduler not set")
        if not self.is_active or (self._reference_images is None and self._reference_latents is None):
            raise ValueError("reference adapter not active or no reference images set")

        # todo may need to handle cfg?
        self.reference_mode = "write"
        try:
            reference_latents = self._fit_reference_latents(sample)
            # todo maybe add same noise to the sample? For now we will send it through with no noise
            self._original_unet_forward(reference_latents, timestep, encoder_hidden_states, *args, **kwargs)
        finally:
            # never leave the processors stuck in write mode if the pass fails
            self.reference_mode = "read"

        self.has_memory = True
        return None

    def _fit_reference_latents(self, sample):
        """Return the reference latents scaled and zero-padded to ``sample``'s spatial size."""
        sd = self.sd_ref()
        if self._reference_latents is None:
            self._reference_latents = sd.encode_images(
                self._reference_images.to(self.device, sd.torch_dtype)
            ).detach()

        # work on a copy so the cached latents are never modified
        reference_latents = self._reference_latents.clone().detach().to(self.device, sd.torch_dtype)

        # if our num of samples are half of incoming, we are doing cfg.
        # Unconditional goes first, so zeros are prepended for that half.
        if reference_latents.shape[0] * 2 == sample.shape[0]:
            reference_latents = torch.cat([torch.zeros_like(reference_latents), reference_latents], dim=0).detach()

        # uniform scale so the reference fits inside the sample
        target_dim2, target_dim3 = sample.shape[2], sample.shape[3]
        scale = min(target_dim2 / reference_latents.shape[2], target_dim3 / reference_latents.shape[3])
        mode = "bilinear" if scale > 1.0 else "bicubic"
        new_dim2 = max(1, int(reference_latents.shape[2] * scale))
        new_dim3 = max(1, int(reference_latents.shape[3] * scale))
        reference_latents = F.interpolate(
            reference_latents,
            size=(new_dim2, new_dim3),
            mode=mode,
            align_corners=False
        )

        # Centre with zero padding. F.pad takes pairs starting from the LAST
        # dimension, so the dim-3 pair comes first. Splitting as
        # (d // 2, d - d // 2) makes the result match the sample exactly even
        # for odd differences, so no further resize is needed.
        pad_dim2 = target_dim2 - new_dim2
        pad_dim3 = target_dim3 - new_dim3
        reference_latents = F.pad(
            reference_latents,
            (pad_dim3 // 2, pad_dim3 - pad_dim3 // 2, pad_dim2 // 2, pad_dim2 - pad_dim2 // 2),
            mode="constant",
            value=0
        )
        return reference_latents

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        """Yield the parameters of every attention processor in ``adapter_modules``."""
        for attn_processor in self.adapter_modules:
            yield from attn_processor.parameters(recurse)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """
        Load ``state_dict["reference_adapter"]`` into the attention processors.

        ``strict`` is accepted for signature compatibility but loading is
        always non-strict. Returns the result of the underlying
        ``ModuleList.load_state_dict`` call.
        """
        strict = False
        return self.adapter_modules.load_state_dict(state_dict["reference_adapter"], strict=strict)

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing on an image encoder, if one is attached."""
        # This adapter does not create an image encoder itself, so unguarded
        # access would raise AttributeError.
        image_encoder = getattr(self, "image_encoder", None)
        if image_encoder is not None:
            image_encoder.gradient_checkpointing = True

    def clear_memory(self):
        """Drop the stored features of every reference processor and reset ``has_memory``."""
        for attn_processor in self.adapter_modules:
            if isinstance(attn_processor, ReferenceAttnProcessor2_0):
                attn_processor._memory = None
        self.has_memory = False