import torch
import sys

from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModel, CLIPVisionModelWithProjection

from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype

sys.path.append(REPOS_ROOT)

from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
    AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion

import torch.nn.functional as F


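# Decoupled cross-attention processor: text tokens and image-prompt tokens are attended to
# separately and then combined as
#   out = Attn(Q, K_text, V_text) + scale * Attn(Q, K_ip, V_ip)
# where K_ip / V_ip come from this processor's to_k_ip / to_v_ip projections, and
# scale / is_active are read from the owning IPAdapter via a weakref.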
class CustomIPAttentionProcessor(IPAttnProcessor2_0):
    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
        super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        is_active = self.adapter_ref().is_active
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if not is_active:
            # adapter disabled: no image tokens to split off, behave like a plain attention processor
            ip_hidden_states = None
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            # split encoder_hidden_states into text tokens and the trailing num_tokens image-prompt tokens
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # ip_hidden_states is None when the adapter is disabled
        if ip_hidden_states is not None:
            # image-prompt branch of the decoupled cross-attention
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

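
# Note: IPAdapter below installs one CustomIPAttentionProcessor per cross-attention layer.
# Each processor reads is_active from the owning adapter through the weakref at call time,
# and IPAdapter.set_scale() updates every processor's scale in place.
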
# loosely based on https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
    """IP-Adapter: conditions the wrapped StableDiffusion UNet on CLIP image embeddings by
    projecting them to extra context tokens and routing those tokens through the
    CustomIPAttentionProcessor layers."""

    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
        super().__init__()
        self.config = adapter_config
        self.sd_ref: weakref.ref = weakref.ref(sd)
        try:
            self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
        except EnvironmentError:
            # fall back to default preprocessing if the checkpoint ships no processor config
            self.clip_image_processor = CLIPImageProcessor()
        self.device = self.sd_ref().unet.device
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            adapter_config.image_encoder_path,
            ignore_mismatched_sizes=True
        )
        self.current_scale = 1.0
        self.is_active = True
        if adapter_config.type == 'ip':
            # ip-adapter: a small projection from the pooled CLIP embedding to num_tokens context tokens
            image_proj_model = ImageProjModel(
                cross_attention_dim=sd.unet.config['cross_attention_dim'],
                clip_embeddings_dim=self.image_encoder.config.projection_dim,
                clip_extra_context_tokens=self.config.num_tokens,  # usually 4
            )
        elif adapter_config.type == 'ip+':
            # ip-adapter-plus: a perceiver-style Resampler over the CLIP hidden states
            heads = 12 if not sd.is_xl else 20
            dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
            # note: loading a checkpoint trained for a different base model raises size-mismatch
            # errors on the Resampler latents (e.g. torch.Size([1, 16, 1280]) vs torch.Size([1, 16, 2048]))
            image_proj_model = Resampler(
                dim=dim,
                depth=4,
                dim_head=64,
                heads=heads,
                num_queries=self.config.num_tokens,  # usually 16
                embedding_dim=self.image_encoder.config.hidden_size,
                output_dim=sd.unet.config['cross_attention_dim'],
                ff_mult=4
            )
        else:
            raise ValueError(f"unknown adapter type: {adapter_config.type}")

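        # The wiring below mirrors the reference IP-Adapter training script: self-attention
        # layers ("attn1") keep the stock AttnProcessor2_0, while every cross-attention layer
        # gets a CustomIPAttentionProcessor whose to_k_ip / to_v_ip weights start as copies of
        # that layer's existing to_k / to_v weights.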
        # init adapter modules
        attn_procs = {}
        unet_sd = sd.unet.state_dict()
        for name in sd.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
                hidden_size = sd.unet.config['block_out_channels'][-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = sd.unet.config['block_out_channels'][block_id]
            else:
                # the upstream code did not handle this case, which would leave hidden_size undefined below
                raise ValueError(f"unknown attn processor name: {name}")
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor2_0()
            else:
                layer_name = name.split(".processor")[0]
                weights = {
                    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
                }

                attn_procs[name] = CustomIPAttentionProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.config.num_tokens,
                    adapter=self
                )
                attn_procs[name].load_state_dict(weights)
        sd.unet.set_attn_processor(attn_procs)
        adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

        sd.adapter = self
        self.unet_ref: weakref.ref = weakref.ref(sd.unet)
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules
        # load the weights if we have some
        if self.config.name_or_path:
            loaded_state_dict = load_ip_adapter_model(
                self.config.name_or_path,
                device='cpu',
                dtype=sd.torch_dtype
            )
            self.load_state_dict(loaded_state_dict)

        self.set_scale(1.0)

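    # Construction sketch (illustrative only; assumes an AdapterConfig carrying the fields this
    # module reads: type, image_encoder_path, num_tokens, name_or_path, train_image_encoder):
    #
    #   adapter = IPAdapter(sd, adapter_config)  # wires processors into sd.unet, sets sd.adapter
    #   adapter.to(sd.unet.device, dtype=sd.torch_dtype)
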
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.image_encoder.to(*args, **kwargs)
        self.image_proj_model.to(*args, **kwargs)
        self.adapter_modules.to(*args, **kwargs)
        return self

    def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        # the attn processors live on the wrapped unet (this class has no self.pipe)
        ip_layers = torch.nn.ModuleList(self.sd_ref().unet.attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])
        if self.config.train_image_encoder and 'image_encoder' in state_dict:
            self.image_encoder.load_state_dict(state_dict["image_encoder"])

    # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
    #     self.load_ip_adapter(state_dict)

    def state_dict(self) -> OrderedDict:
        state_dict = OrderedDict()
        state_dict["image_proj"] = self.image_proj_model.state_dict()
        state_dict["ip_adapter"] = self.adapter_modules.state_dict()
        if self.config.train_image_encoder:
            state_dict["image_encoder"] = self.image_encoder.state_dict()
        return state_dict

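    # scale weights the image-attention branch in every CustomIPAttentionProcessor:
    # 0.0 effectively ignores the image prompt, 1.0 is full strength.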
    def get_scale(self):
        return self.current_scale

    def set_scale(self, scale):
        self.current_scale = scale
        for attn_processor in self.sd_ref().unet.attn_processors.values():
            if isinstance(attn_processor, CustomIPAttentionProcessor):
                attn_processor.scale = scale

    @torch.no_grad()
    def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
                                       drop=False) -> torch.Tensor:
        # todo: add support for sdxl
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        # note: dtype is hard-coded to float16 here rather than taken from the model config
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        if drop:
            # a zeroed input acts as the "no image" condition for dropout / negatives
            clip_image = clip_image * 0
        # penultimate hidden states, matching the 'ip+' Resampler's embedding_dim (hidden_size)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        return clip_image_embeds

    def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False,
                                           is_training=False) -> torch.Tensor:
        with torch.no_grad():
            # training tensors should be in the 0-1 range
            # todo: add support for sdxl
            if tensors_0_1.ndim == 3:
                tensors_0_1 = tensors_0_1.unsqueeze(0)
            tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
            # reject tensors that are clearly outside the 0-1 range (a small tolerance is allowed)
            if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                    tensors_0_1.min(), tensors_0_1.max()
                ))

            clip_image = self.clip_image_processor(
                images=tensors_0_1,
                return_tensors="pt",
                do_resize=True,
                do_rescale=False,
            ).pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
        if drop:
            clip_image = clip_image * 0
        with torch.set_grad_enabled(is_training):
            if is_training:
                self.image_encoder.train()
            else:
                self.image_encoder.eval()
            clip_output = self.image_encoder(clip_image, output_hidden_states=True)
            clip_image_embeds = clip_output.hidden_states[-2]
        return clip_image_embeds

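    # forward() appends the projected image tokens to the text embeddings, so the UNet
    # cross-attention sees [text_tokens, image_tokens]; CustomIPAttentionProcessor later
    # splits them apart again using num_tokens.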
    # use drop for prompt dropout, or negatives
    def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
        clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
        return embeddings

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        for attn_processor in self.adapter_modules:
            yield from attn_processor.parameters(recurse)
        yield from self.image_proj_model.parameters(recurse)
        if self.config.train_image_encoder:
            yield from self.image_encoder.parameters(recurse)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
        if self.config.train_image_encoder and 'image_encoder' in state_dict:
            self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)

    def enable_gradient_checkpointing(self):
        # setting the bare attribute does not reach the encoder layers; use the transformers API as well
        self.image_encoder.gradient_checkpointing_enable()
        self.image_encoder.gradient_checkpointing = True
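
# Usage sketch (illustrative only, not called anywhere in this module): condition a prompt on a
# reference image, assuming `sd` is the toolkit StableDiffusion wrapper this adapter was built
# with and `prompt_embeds` is a PromptEmbeds produced by the text encoder.
#
#   image_embeds = sd.adapter.get_clip_image_embeds_from_pil(Image.open("reference.png"))
#   prompt_embeds = sd.adapter(prompt_embeds, image_embeds)  # appends the image tokens
#   sd.adapter.set_scale(0.8)                                # dial image influence up or down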