Added more functionality for ip adapters

Jaret Burkett
2023-12-19 09:54:56 -07:00
parent b653906715
commit 82098e5d6e
3 changed files with 200 additions and 31 deletions

@@ -480,7 +480,12 @@ class ControlFileItemDTOMixin:
print(f"Error: {e}")
print(f"Error loading image: {self.control_path}")
if not self.full_size_control_images:
if self.full_size_control_images:
# we just scale them to 512x512:
w, h = img.size
img = img.resize((512, 512), Image.BICUBIC)
else:
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
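When full_size_control_images is set, the control image is now simply resampled onto a fixed 512x512 grid; otherwise its aspect ratio is checked against the scale targets as before. A minimal standalone sketch of the new resize path (the file path is a placeholder, not from the commit):

from PIL import Image

# hypothetical standalone equivalent of the new full_size_control_images branch
img = Image.open("control.png").convert("RGB")  # placeholder control image path
w, h = img.size  # original size is read first, as in the diff
img = img.resize((512, 512), Image.BICUBIC)  # bicubic resample to the fixed 512x512 grid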

@@ -12,7 +12,7 @@ from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
@@ -27,10 +27,120 @@ from transformers import (
CLIPVisionModelWithProjection,
)
from diffusers.models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
import torch.nn.functional as F

class CustomIPAttentionProcessor(IPAttnProcessor2_0):
    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
        super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        is_active = self.adapter_ref().is_active
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # ip_hidden_states will be None if the adapter is disabled
        if not is_active:
            ip_hidden_states = None
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            # split encoder_hidden_states into text tokens and ip (image) tokens
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # ip_hidden_states is None when the adapter is disabled
        if ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
# loosely based on https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
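Because __call__ splits encoder_hidden_states at end_pos = seq_len - num_tokens, the processor expects the projected image tokens to be concatenated after the text embeddings before the UNet call. A small illustrative sketch of that layout (shapes assume SD 1.x dimensions and the 4-token 'ip' variant; all names here are illustrative, not from the commit):

import torch

# text embeddings from the text encoder, image tokens from the projection model
text_embeds = torch.randn(2, 77, 768)   # (batch, text_seq_len, cross_attention_dim)
image_tokens = torch.randn(2, 4, 768)   # (batch, num_tokens, cross_attention_dim)
encoder_hidden_states = torch.cat([text_embeds, image_tokens], dim=1)

# inside __call__ (when the adapter is active): end_pos = 81 - 4 = 77,
# so slicing recovers the text tokens and the ip tokens separately
end_pos = encoder_hidden_states.shape[1] - 4
assert end_pos == 77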
@@ -44,6 +154,8 @@ class IPAdapter(torch.nn.Module):
self.clip_image_processor = CLIPImageProcessor()
self.device = self.sd_ref().unet.device
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)
self.current_scale = 1.0
self.is_active = True
if adapter_config.type == 'ip':
# ip-adapter
image_proj_model = ImageProjModel(
@@ -84,14 +196,29 @@ class IPAdapter(torch.nn.Module):
# the reference implementation didn't have this check, but an unmatched name would leave hidden_size undefined below
raise ValueError(f"unknown attn processor name: {name}")
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
attn_procs[name] = AttnProcessor2_0()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
if adapter_config.type == 'ip':
# ip-adapter
num_tokens = 4
elif adapter_config.type == 'ip+':
# ip-adapter-plus
num_tokens = 16
else:
raise ValueError(f"unknown adapter type: {adapter_config.type}")
attn_procs[name] = CustomIPAttentionProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_tokens,
adapter=self
)
attn_procs[name].load_state_dict(weights)
sd.unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
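The hunk only shows the tail of the processor-setup loop; for context, the reference IP-Adapter setup derives name, cross_attention_dim, and hidden_size from the UNet's processor keys roughly as below (a sketch, assuming a diffusers UNet2DConditionModel named unet; not part of this commit):

# per-layer dimension lookup as done in the reference IP-Adapter setup
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no cross context), attn2 is cross-attention
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]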
@@ -132,13 +259,18 @@ class IPAdapter(torch.nn.Module):
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
return state_dict
def get_scale(self):
return self.current_scale
def set_scale(self, scale):
self.current_scale = scale
for attn_processor in self.sd_ref().unet.attn_processors.values():
if isinstance(attn_processor, (IPAttnProcessor, IPAttnProcessor2_0)):  # CustomIPAttentionProcessor extends IPAttnProcessor2_0
attn_processor.scale = scale
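The new get_scale/set_scale pair and the is_active flag give callers runtime control over how strongly the image prompt conditions generation. A hypothetical usage sketch, assuming an already constructed IPAdapter named adapter:

# weight the image-prompt conditioning at half strength for subsequent generations
adapter.set_scale(0.5)
assert adapter.get_scale() == 0.5

# temporarily disable the adapter; the custom attention processors then skip the ip branch
adapter.is_active = False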
@torch.no_grad()
def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], drop=False) -> torch.Tensor:
def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
drop=False) -> torch.Tensor:
# todo: add support for sdxl
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
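The rest of this method falls outside the hunk; for context, the usual IP-Adapter path preprocesses the PIL images with CLIPImageProcessor and encodes them with the vision tower, roughly as sketched below (the model path and the zero-image handling of drop are assumptions, not taken from this commit):

import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

clip_image_processor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")  # assumed path

pil_image = [Image.new("RGB", (512, 512))]  # placeholder input batch
clip_image = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = image_encoder(clip_image).image_embeds  # (batch, projection_dim)

# drop=True is commonly implemented by encoding a zeroed image to get an unconditional embed
uncond_embeds = image_encoder(torch.zeros_like(clip_image)).image_embeds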
@@ -191,4 +323,3 @@ class IPAdapter(torch.nn.Module):
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
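Taken together, state_dict and load_state_dict round-trip the adapter through a two-key checkpoint ("image_proj" and "ip_adapter"). A small hypothetical save/load example, assuming an existing IPAdapter instance named adapter:

import torch

# save: the checkpoint holds the image projection weights and the per-layer ip attention weights
torch.save(adapter.state_dict(), "ip_adapter.bin")

# load into a freshly constructed IPAdapter built with the same config
checkpoint = torch.load("ip_adapter.bin", map_location="cpu")
adapter.load_state_dict(checkpoint, strict=True)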