mirror of https://github.com/ostris/ai-toolkit.git
Added more functionality for ip adapters
@@ -94,9 +94,18 @@ class SDTrainer(BaseSDTrainProcess):
        noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
        noise_pred = noise_pred * (noise_norm / noise_pred_norm)

        if self.train_config.inverted_mask_prior:
        if self.train_config.inverted_mask_prior and prior_pred is not None:
            # we need to make the noise prediction be a masked blending of noise and prior_pred
            prior_mask_multiplier = 1.0 - mask_multiplier
            stretched_mask_multiplier = value_map(
                mask_multiplier,
                batch.file_items[0].dataset_config.mask_min_value,
                1.0,
                0.0,
                1.0
            )

            prior_mask_multiplier = 1.0 - stretched_mask_multiplier

            # target_mask_multiplier = mask_multiplier
            # mask_multiplier = 1.0
        target = noise
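value_map is a toolkit helper whose definition is outside this hunk; assuming it is the usual linear range remap, the stretch above renormalizes mask values from [mask_min_value, 1.0] to [0.0, 1.0] before inverting them, roughly:

    def value_map(x, min_in, max_in, min_out, max_out):
        # hypothetical sketch of a linear remap from [min_in, max_in] to [min_out, max_out]
        return (x - min_in) / (max_in - min_in) * (max_out - min_out) + min_out

With mask_min_value = 0.2, a mask pixel at 0.2 stretches to 0.0 and a pixel at 1.0 stays at 1.0, so prior_mask_multiplier = 1.0 - stretched is 1.0 outside the trained region and 0.0 where the mask is fully on.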
@@ -152,7 +161,8 @@ class SDTrainer(BaseSDTrainProcess):
        # multiply by our mask
        loss = loss * mask_multiplier

        if self.train_config.inverted_mask_prior:
        prior_loss = None
        if self.train_config.inverted_mask_prior and prior_pred is not None:
            # add a loss to unmasked areas of the prior for unmasked regularization
            prior_loss = torch.nn.functional.mse_loss(
                prior_pred.float(),
@@ -160,10 +170,16 @@ class SDTrainer(BaseSDTrainProcess):
                reduction="none"
            )
            prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
            loss = loss + prior_loss
            if torch.isnan(prior_loss).any():
                raise ValueError("Prior loss is nan")

            prior_loss = prior_loss.mean([1, 2, 3])

        loss = loss.mean([1, 2, 3])

        if prior_loss is not None:
            loss = loss + prior_loss

        if self.train_config.learnable_snr_gos:
            # add snr_gamma
            loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
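Condensed out of the trainer, the masked loss plus inverted-mask prior regularization amounts to roughly the following sketch (the second mse_loss argument is cut off by the hunk, so comparing the current prediction against prior_pred is an assumption, as are the (B, C, H, W) shapes):

    import torch.nn.functional as F

    def masked_loss_with_inverted_prior(noise_pred, target, prior_pred, mask, prior_multiplier=1.0):
        # supervise the masked (trained) region against the normal target
        loss = F.mse_loss(noise_pred.float(), target.float(), reduction="none") * mask
        # pull the unmasked region toward the prior (adapter/network disabled) prediction
        prior_loss = F.mse_loss(noise_pred.float(), prior_pred.float(), reduction="none") * (1.0 - mask) * prior_multiplier
        return (loss + prior_loss).mean(dim=[1, 2, 3]).mean()

The real code additionally NaN-checks the prior term and applies the learnable SNR weighting afterwards.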
@@ -491,14 +507,25 @@ class SDTrainer(BaseSDTrainProcess):
            noise: torch.Tensor,
            **kwargs
    ):
        was_unet_training = self.sd.unet.training
        was_network_active = False
        if self.network is not None:
            was_network_active = self.network.is_active
            self.network.is_active = False
        is_ip_adapter = False
        was_ip_adapter_active = False
        if self.adapter is not None and isinstance(self.adapter, IPAdapter):
            is_ip_adapter = True
            was_ip_adapter_active = self.adapter.is_active
            self.adapter.is_active = False

        # do a prediction here so we can match its output with network multiplier set to 0.0
        with torch.no_grad():
            dtype = get_torch_dtype(self.train_config.dtype)
            # don't use network on this
            # self.network.multiplier = 0.0
            was_network_active = self.network.is_active
            self.network.is_active = False
            self.sd.unet.eval()

            prior_pred = self.sd.predict_noise(
                latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
                conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
@@ -506,14 +533,19 @@ class SDTrainer(BaseSDTrainProcess):
                guidance_scale=1.0,
                **pred_kwargs  # adapter residuals in here
            )
            self.sd.unet.train()
            if was_unet_training:
                self.sd.unet.train()
            prior_pred = prior_pred.detach()
            # remove the residuals as we won't use them on prediction when matching control
            if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
                del pred_kwargs['down_intrablock_additional_residuals']

            if is_ip_adapter:
                self.adapter.is_active = was_ip_adapter_active
            # restore network
            # self.network.multiplier = network_weight_list
            self.network.is_active = was_network_active
            if self.network is not None:
                self.network.is_active = was_network_active
            return prior_pred

    def before_unet_predict(self):
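The save/disable/restore bookkeeping that brackets the prior prediction could equally be expressed as a context manager; a minimal sketch of that pattern (a hypothetical helper, not part of this commit), assuming each object exposes a boolean is_active flag the way the LoRA network and IPAdapter do here:

    from contextlib import contextmanager

    @contextmanager
    def temporarily_disabled(*toggleables):
        # remember which objects were active, switch them off, then restore on exit
        saved = [(t, t.is_active) for t in toggleables if t is not None]
        try:
            for t, _ in saved:
                t.is_active = False
            yield
        finally:
            for t, was_active in saved:
                t.is_active = was_active

Usage would be along the lines of: with temporarily_disabled(self.network, self.adapter): prior_pred = self.sd.predict_noise(...).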
@@ -752,19 +784,6 @@ class SDTrainer(BaseSDTrainProcess):
            pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals

        prior_pred = None
        if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or self.do_prior_prediction:
            with self.timer('prior predict'):
                prior_pred = self.get_prior_prediction(
                    noisy_latents=noisy_latents,
                    conditional_embeds=conditional_embeds,
                    match_adapter_assist=match_adapter_assist,
                    network_weight_list=network_weight_list,
                    timesteps=timesteps,
                    pred_kwargs=pred_kwargs,
                    noise=noise,
                    batch=batch,
                )

        if self.adapter and isinstance(self.adapter, IPAdapter):
            with self.timer('encode_adapter_embeds'):
@@ -788,6 +807,20 @@ class SDTrainer(BaseSDTrainProcess):
            with self.timer('encode_adapter'):
                conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach())

        prior_pred = None
        if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg):
            with self.timer('prior predict'):
                prior_pred = self.get_prior_prediction(
                    noisy_latents=noisy_latents,
                    conditional_embeds=conditional_embeds,
                    match_adapter_assist=match_adapter_assist,
                    network_weight_list=network_weight_list,
                    timesteps=timesteps,
                    pred_kwargs=pred_kwargs,
                    noise=noise,
                    batch=batch,
                )

        self.before_unet_predict()
        # do a prior pred if we have an unconditional image, we will swap out the guidance later
        if batch.unconditional_latents is not None or self.do_guided_loss:
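The self.adapter(...) call above returns the text embeddings with the image tokens appended, which is what the attention processor later splits back apart. Conceptually the conditioning step looks like this sketch (the standard IP-Adapter pattern; this repo's exact forward may differ):

    import torch

    def ip_adapter_conditioning(text_embeds, clip_image_embeds, image_proj_model):
        # project pooled CLIP image features into a few extra context tokens
        image_tokens = image_proj_model(clip_image_embeds)    # (B, num_tokens, cross_attention_dim)
        # append them to the text tokens
        return torch.cat([text_embeds, image_tokens], dim=1)  # (B, 77 + num_tokens, dim)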
@@ -480,7 +480,12 @@ class ControlFileItemDTOMixin:
            print(f"Error: {e}")
            print(f"Error loading image: {self.control_path}")

        if not self.full_size_control_images:
        if self.full_size_control_images:
            # we just scale them to 512x512:
            w, h = img.size
            img = img.resize((512, 512), Image.BICUBIC)

        else:
            w, h = img.size
            if w > h and self.scale_to_width < self.scale_to_height:
                # throw error, they should match
@@ -12,7 +12,7 @@ from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
@@ -27,10 +27,120 @@ from transformers import (
    CLIPVisionModelWithProjection,
)

from diffusers.models.attention_processor import (
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)
import torch.nn.functional as F


class CustomIPAttentionProcessor(IPAttnProcessor2_0):
    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
        super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)

    def __call__(
            self,
            attn,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            temb=None,
    ):
        is_active = self.adapter_ref().is_active
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # will be none if disabled
        if not is_active:
            ip_hidden_states = None
            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # will be none if disabled
        if ip_hidden_states is not None:
            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
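For orientation on the token split inside __call__: with SD1.5 text conditioning (77 tokens) and the base ip-adapter (4 image tokens), encoder_hidden_states arrives as 81 tokens and end_pos = 81 - 4 = 77. A minimal illustration, assuming those shapes:

    import torch

    batch, text_len, num_tokens, dim = 2, 77, 4, 768
    encoder_hidden_states = torch.randn(batch, text_len + num_tokens, dim)
    end_pos = encoder_hidden_states.shape[1] - num_tokens   # 77
    text_states = encoder_hidden_states[:, :end_pos, :]     # (2, 77, 768) -> regular cross-attention
    ip_states = encoder_hidden_states[:, end_pos:, :]       # (2, 4, 768)  -> to_k_ip / to_v_ip branch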
@@ -44,6 +154,8 @@ class IPAdapter(torch.nn.Module):
        self.clip_image_processor = CLIPImageProcessor()
        self.device = self.sd_ref().unet.device
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)
        self.current_scale = 1.0
        self.is_active = True
        if adapter_config.type == 'ip':
            # ip-adapter
            image_proj_model = ImageProjModel(
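The image encoder set up here is what later produces the CLIP image embeddings fed to the projection model. In isolation that step generally looks like the following (a sketch using the transformers classes imported above; the encoder path is whatever adapter_config.image_encoder_path points to, "openai/clip-vit-large-patch14" is only an assumed example):

    from PIL import Image
    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

    processor = CLIPImageProcessor()
    encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    image = Image.open("example.png").convert("RGB")   # hypothetical input
    pixels = processor(images=[image], return_tensors="pt").pixel_values
    image_embeds = encoder(pixels).image_embeds        # pooled, projected CLIP features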
@@ -84,14 +196,29 @@ class IPAdapter(torch.nn.Module):
                # they didn't have this, but omitting it would lead to an undefined variable below
                raise ValueError(f"unknown attn processor name: {name}")
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
                attn_procs[name] = AttnProcessor2_0()
            else:
                layer_name = name.split(".processor")[0]
                weights = {
                    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
                }
                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
                if adapter_config.type == 'ip':
                    # ip-adapter
                    num_tokens = 4
                elif adapter_config.type == 'ip+':
                    # ip-adapter-plus
                    num_tokens = 16
                else:
                    raise ValueError(f"unknown adapter type: {adapter_config.type}")

                attn_procs[name] = CustomIPAttentionProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=num_tokens,
                    adapter=self
                )
                attn_procs[name].load_state_dict(weights)
        sd.unet.set_attn_processor(attn_procs)
        adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
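This hunk assumes name, hidden_size, cross_attention_dim and unet_sd were derived earlier in the loop over sd.unet.attn_processors. The conventional derivation, following the IP-Adapter tutorial referenced earlier in the file (a sketch; the exact code sits above the hunk and may differ):

    def resolve_attn_dims(unet, name):
        # attn1 is self-attention (no cross-attention dim); attn2 attends to the text/image tokens
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            raise ValueError(f"unknown attn processor name: {name}")
        return hidden_size, cross_attention_dim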
@@ -132,13 +259,18 @@ class IPAdapter(torch.nn.Module):
        state_dict["ip_adapter"] = self.adapter_modules.state_dict()
        return state_dict

    def get_scale(self):
        return self.current_scale

    def set_scale(self, scale):
        self.current_scale = scale
        for attn_processor in self.sd_ref().unet.attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    @torch.no_grad()
    def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], drop=False) -> torch.Tensor:
    def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
                                       drop=False) -> torch.Tensor:
        # todo: add support for sdxl
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
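Together with is_active, the new scale accessors give two ways to vary the adapter's influence at runtime; a short usage sketch against an IPAdapter instance:

    # drop the adapter entirely, e.g. for the prior prediction pass
    adapter.is_active = False
    # ... predict without image conditioning ...
    adapter.is_active = True

    # or keep it active but scale its contribution in the attention layers
    adapter.set_scale(0.5)
    assert adapter.get_scale() == 0.5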
@@ -191,4 +323,3 @@ class IPAdapter(torch.nn.Module):
    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)