Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-03-14 15:07:22 +00:00

Commit: Added reference adapters, many bug fixes, more IP adapter work and customizability
@@ -15,6 +15,7 @@ from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
 from toolkit.image_utils import show_tensors, show_latents
 from toolkit.ip_adapter import IPAdapter
 from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
+from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
 from toolkit.train_tools import get_torch_dtype, apply_snr_weight, add_all_snr_to_noise_scheduler, \
     apply_learnable_snr_gos, LearnableSNRGamma
@@ -285,9 +286,11 @@ class SDTrainer(BaseSDTrainProcess):
     if torch.isnan(prior_loss).any():
         raise ValueError("Prior loss is nan")

-    # prior_loss = prior_loss.mean([1, 2, 3])
-    loss = loss + prior_loss
+    prior_loss = prior_loss.mean([1, 2, 3])
+    # loss = loss + prior_loss
 loss = loss.mean([1, 2, 3])
+if prior_loss is not None:
+    loss = loss + prior_loss

 if not self.train_config.train_turbo:
     if self.train_config.learnable_snr_gos:
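The reordering above reduces the prior loss to one value per sample before adding it to the main loss, rather than summing full-resolution maps. A minimal sketch of that reduction pattern, with illustrative tensors rather than the trainer's real variables:

    import torch

    # per-pixel losses over a batch of latents, shape (B, C, H, W)
    loss = torch.rand(2, 4, 8, 8)
    prior_loss = torch.rand(2, 4, 8, 8)

    # reduce each to one value per sample before combining
    prior_loss = prior_loss.mean([1, 2, 3])  # -> shape (B,)
    loss = loss.mean([1, 2, 3])              # -> shape (B,)
    if prior_loss is not None:
        loss = loss + prior_loss
    print(loss.shape)  # torch.Size([2])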
@@ -623,11 +626,11 @@ class SDTrainer(BaseSDTrainProcess):
 if self.network is not None:
     was_network_active = self.network.is_active
     self.network.is_active = False
-is_ip_adapter = False
-was_ip_adapter_active = False
-if self.adapter is not None and isinstance(self.adapter, IPAdapter):
-    is_ip_adapter = True
-    was_ip_adapter_active = self.adapter.is_active
+can_disable_adapter = False
+was_adapter_active = False
+if self.adapter is not None and (isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ReferenceAdapter)):
+    can_disable_adapter = True
+    was_adapter_active = self.adapter.is_active
     self.adapter.is_active = False

 # do a prediction here so we can match its output with network multiplier set to 0.0
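The save-flag / disable / predict / restore sequence is the standard way to get a prior prediction that the network and adapter cannot influence. The same idea as a small context manager, assuming only that the wrapped object exposes an is_active flag (the trainer itself restores the flags inline):

    from contextlib import contextmanager

    @contextmanager
    def temporarily_disabled(module):
        # remember the flag, switch it off, and always restore it on exit
        was_active = module.is_active
        module.is_active = False
        try:
            yield module
        finally:
            module.is_active = was_active

The finally block guarantees the module is re-enabled even if the prior prediction raises.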
@@ -666,8 +669,8 @@ class SDTrainer(BaseSDTrainProcess):
 if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
     del pred_kwargs['down_intrablock_additional_residuals']

-if is_ip_adapter:
-    self.adapter.is_active = was_ip_adapter_active
+if can_disable_adapter:
+    self.adapter.is_active = was_adapter_active
 # restore network
 # self.network.multiplier = network_weight_list
 if self.network is not None:
@@ -950,12 +953,7 @@ class SDTrainer(BaseSDTrainProcess):

 if self.adapter and isinstance(self.adapter, IPAdapter):
     with self.timer('encode_adapter_embeds'):
-        if has_clip_image:
-            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
-                clip_images.detach().to(self.device_torch, dtype=dtype),
-                is_training=True
-            )
-        elif is_reg:
+        if is_reg:
             # we will zero it out in the img embedder
             clip_images = torch.zeros(
                 (noisy_latents.shape[0], 3, 512, 512),
@@ -967,6 +965,11 @@ class SDTrainer(BaseSDTrainProcess):
                 drop=True,
                 is_training=True
             )
+        elif has_clip_image:
+            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
+                clip_images.detach().to(self.device_torch, dtype=dtype),
+                is_training=True
+            )
         else:
             raise ValueError("Adapter images must now be loaded with the dataloader or be a reg image")

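The reordered branch tests is_reg before has_clip_image, so regularization batches always get zeroed CLIP images (and dropped conditioning) even when a clip image is present. A hypothetical helper capturing that precedence (pick_clip_images is illustrative, not a toolkit function):

    import torch

    def pick_clip_images(is_reg, has_clip_image, clip_images, batch_size, device, dtype):
        # reg takes precedence: feed zeros and let the embedder drop the conditioning
        if is_reg:
            return torch.zeros((batch_size, 3, 512, 512), device=device, dtype=dtype)
        if has_clip_image:
            return clip_images.detach().to(device, dtype=dtype)
        raise ValueError("adapter images must come from the dataloader or a reg batch")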
@@ -978,12 +981,26 @@ class SDTrainer(BaseSDTrainProcess):
 with self.timer('encode_adapter'):
     conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)

+if self.adapter and isinstance(self.adapter, ReferenceAdapter):
+    # pass in our scheduler
+    self.adapter.noise_scheduler = self.lr_scheduler
+    if has_clip_image or has_adapter_img:
+        img_to_use = clip_images if has_clip_image else adapter_images
+        # currently 0-1, needs to be -1 to 1
+        reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype)
+        self.adapter.set_reference_images(reference_images)
+        self.adapter.noise_scheduler = self.sd.noise_scheduler
+    elif is_reg:
+        self.adapter.set_blank_reference_images(noisy_latents.shape[0])
+    else:
+        self.adapter.set_reference_images(None)
+
 prior_pred = None

 do_reg_prior = False
-if is_reg and (self.network is not None or self.adapter is not None):
-    # we are doing a reg image and we have a network or adapter
-    do_reg_prior = True
+# if is_reg and (self.network is not None or self.adapter is not None):
+#     # we are doing a reg image and we have a network or adapter
+#     do_reg_prior = True

 do_inverted_masked_prior = False
 if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:

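The (img_to_use - 0.5) * 2 rescale maps dataloader tensors from [0, 1] onto the [-1, 1] range the VAE encoder expects. A quick check of the mapping:

    import torch

    img = torch.tensor([0.0, 0.25, 0.5, 1.0])  # dataloader range
    ref = (img - 0.5) * 2                      # VAE input range
    print(ref)  # tensor([-1.0000, -0.5000,  0.0000,  1.0000])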
@@ -31,6 +31,7 @@ from toolkit.network_mixins import Network
 from toolkit.optimizer import get_optimizer
 from toolkit.paths import CONFIG_ROOT
 from toolkit.progress_bar import ToolkitProgressBar
+from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.sampler import get_sampler
 from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
     load_ip_adapter_model
@@ -140,7 +141,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

 # to hold network if there is one
 self.network: Union[Network, None] = None
-self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, None] = None
+self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, None] = None
 self.embedding: Union[Embedding, None] = None

 is_training_adapter = self.adapter_config is not None and self.adapter_config.train
@@ -771,8 +772,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
     num_train_timesteps, device=self.device_torch
 )

+content_or_style = self.train_config.content_or_style
+if is_reg:
+    content_or_style = self.train_config.content_or_style_reg
+
 # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
-if self.train_config.content_or_style in ['style', 'content']:
+if content_or_style in ['style', 'content']:
     # this is from diffusers training code
     # Cubic sampling for favoring later or earlier timesteps
     # For more details about why cubic sampling is used for content / structure,
@@ -783,9 +788,9 @@ class BaseSDTrainProcess(BaseTrainProcess):

     orig_timesteps = torch.rand((batch_size,), device=latents.device)

-    if self.train_config.content_or_style == 'content':
+    if content_or_style == 'content':
         timestep_indices = orig_timesteps ** 3 * self.sd.noise_scheduler.config['num_train_timesteps']
-    elif self.train_config.content_or_style == 'style':
+    elif content_or_style == 'style':
         timestep_indices = (1 - orig_timesteps ** 3) * self.sd.noise_scheduler.config['num_train_timesteps']

     timestep_indices = value_map(
@@ -800,7 +805,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         max_noise_steps - 1
     )

-elif self.train_config.content_or_style == 'balanced':
+elif content_or_style == 'balanced':
     if min_noise_steps == max_noise_steps:
         timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
     else:
@@ -813,7 +818,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
     )
     timestep_indices = timestep_indices.long()
 else:
-    raise ValueError(f"Unknown content_or_style {self.train_config.content_or_style}")
+    raise ValueError(f"Unknown content_or_style {content_or_style}")

 # convert the timestep_indices to a timestep
 timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
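The hoisted content_or_style variable lets reg batches use their own sampling bias via content_or_style_reg. The cubic schedule itself skews which timesteps get trained: t ** 3 piles samples near low timesteps, 1 - t ** 3 near high ones. A quick numerical check of that skew:

    import torch

    n = 1000                   # num_train_timesteps
    t = torch.rand(100_000)    # uniform draws in [0, 1)

    content_idx = t ** 3 * n        # mass concentrates near timestep 0
    style_idx = (1 - t ** 3) * n    # mass concentrates near timestep n

    print(content_idx.median())  # ~125: biased toward low timesteps
    print(style_idx.median())    # ~875: biased toward high timesteps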
@@ -824,9 +829,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
     height=latents.shape[2],
     width=latents.shape[3],
     batch_size=batch_size,
-    noise_offset=self.train_config.noise_offset
+    noise_offset=self.train_config.noise_offset,
 ).to(self.device_torch, dtype=dtype)

+# add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents
+# this will negate any noise offsets
+if self.train_config.dynamic_noise_offset and not is_reg:
+    latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2
+    # shift by the channel mean so that we compensate for the mean of the latents on the noise offset per channel
+    noise = noise + latents_channel_mean
+
 if self.train_config.loss_target == 'differential_noise':
     differential = latents - unaugmented_latents
     # add noise to differential
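Dynamic noise offset shifts the sampled noise toward each latent channel's spatial mean instead of using a fixed offset, which is why the comment warns it negates any other noise offsets. A minimal sketch (the divide-by-two damping follows the diff):

    import torch

    latents = torch.randn(2, 4, 64, 64) + 0.3  # latents with a nonzero channel mean
    noise = torch.randn_like(latents)

    latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2
    noise = noise + latents_channel_mean
    print(noise.mean(dim=(2, 3))[0])  # each channel's noise mean pulled toward ~0.15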
@@ -912,6 +924,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
     suffix = 't2i'
 elif self.adapter_config.type == 'clip':
     suffix = 'clip'
+elif self.adapter_config.type == 'reference':
+    suffix = 'ref'
 else:
     suffix = 'ip'
 adapter_name = self.name
@@ -943,6 +957,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
     sd=self.sd,
     adapter_config=self.adapter_config,
 )
+elif self.adapter_config.type == 'reference':
+    self.adapter = ReferenceAdapter(
+        sd=self.sd,
+        adapter_config=self.adapter_config,
+    )
 else:
     self.adapter = IPAdapter(
         sd=self.sd,
@@ -1441,6 +1460,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
     did_first_flush = True
 # flush()
 # setup the networks to gradient checkpointing and everything works
+if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
+    self.adapter.clear_memory()

 with torch.no_grad():
     # torch.cuda.empty_cache()
@@ -31,12 +31,18 @@ def get_mean_std(tensor):
 def adain(content_features, style_features):
     # Assumes that the content and style features are of shape (batch_size, channels, width, height)

+    dims = [2, 3]
+    if len(content_features.shape) == 3:
+        # content_features = content_features.unsqueeze(0)
+        # style_features = style_features.unsqueeze(0)
+        dims = [1]
+
     # Step 1: Calculate mean and variance of content features
-    content_mean, content_var = torch.mean(content_features, dim=[2, 3], keepdim=True), torch.var(content_features,
-                                                                                                  dim=[2, 3],
+    content_mean, content_var = torch.mean(content_features, dim=dims, keepdim=True), torch.var(content_features,
+                                                                                                dim=dims,
                                                                                                 keepdim=True)
     # Step 2: Calculate mean and variance of style features
-    style_mean, style_var = torch.mean(style_features, dim=[2, 3], keepdim=True), torch.var(style_features, dim=[2, 3],
+    style_mean, style_var = torch.mean(style_features, dim=dims, keepdim=True), torch.var(style_features, dim=dims,
                                                                                            keepdim=True)

     # Step 3: Normalize content features
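adain re-standardizes the content features to the style features' per-channel statistics; the new dims switch simply extends it from 4-D feature maps to 3-D token tensors. A self-contained sketch of the 4-D case:

    import torch

    def adain(content, style, eps=1e-5):
        # give the content features the style's per-channel mean and variance
        dims = [2, 3] if content.ndim == 4 else [1]
        c_mean, c_var = content.mean(dim=dims, keepdim=True), content.var(dim=dims, keepdim=True)
        s_mean, s_var = style.mean(dim=dims, keepdim=True), style.var(dim=dims, keepdim=True)
        normalized = (content - c_mean) / torch.sqrt(c_var + eps)
        return normalized * torch.sqrt(s_var + eps) + s_mean

    content = torch.randn(1, 64, 16, 16)
    style = torch.randn(1, 64, 16, 16) * 2 + 1
    out = adain(content, style)
    print(out.mean().item(), out.std().item())  # roughly 1 and 2, matching the style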
@@ -178,6 +178,7 @@ class TrainConfig:
     def __init__(self, **kwargs):
         self.noise_scheduler = kwargs.get('noise_scheduler', 'ddpm')
         self.content_or_style: ContentOrStyleType = kwargs.get('content_or_style', 'balanced')
+        self.content_or_style_reg: ContentOrStyleType = kwargs.get('content_or_style', 'balanced')
         self.steps: int = kwargs.get('steps', 1000)
         self.lr = kwargs.get('lr', 1e-6)
         self.unet_lr = kwargs.get('unet_lr', self.lr)
@@ -268,6 +269,8 @@ class TrainConfig:
         if self.train_turbo and not self.noise_scheduler.startswith("euler"):
             raise ValueError(f"train_turbo is only supported with euler and euler_a noise schedulers")

+        self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
+

 class ModelConfig:
     def __init__(self, **kwargs):
@@ -232,6 +232,10 @@ class IPAdapter(torch.nn.Module):
 elif adapter_config.type == 'ip+':
     heads = 12 if not sd.is_xl else 20
     dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
+    embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1]
+
+    if self.config.image_encoder_arch == 'safe':
+        embedding_dim = self.config.safe_channels
 # size mismatch for latents: copying a param with shape torch.Size([1, 16, 1280]) from checkpoint, the shape in current model is torch.Size([1, 16, 2048]).
 # size mismatch for latents: copying a param with shape torch.Size([1, 32, 2048]) from checkpoint, the shape in current model is torch.Size([1, 16, 1280])
 # ip-adapter-plus
@@ -241,7 +245,7 @@ class IPAdapter(torch.nn.Module):
     dim_head=64,
     heads=heads,
     num_queries=self.config.num_tokens,  # usually 16
-    embedding_dim=self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else self.image_encoder.config.hidden_sizes[-1],
+    embedding_dim=embedding_dim,
     output_dim=sd.unet.config['cross_attention_dim'],
     ff_mult=4
 )
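Hoisting embedding_dim above the Resampler call is what lets the 'safe' encoder arch override it before construction. The selection order, condensed (pick_embedding_dim and the safe_channels default are illustrative):

    def pick_embedding_dim(arch, encoder_config, safe_channels=512):
        # ViT-style encoders expose hidden_size; convnext exposes per-stage hidden_sizes
        if arch == "convnext":
            return encoder_config.hidden_sizes[-1]
        if arch == "safe":
            return safe_channels
        return encoder_config.hidden_size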
toolkit/reference_adapter.py (new file, 411 lines)
@@ -0,0 +1,411 @@
import math

import torch
import sys

from PIL import Image
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.basic import adain
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype

sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
    AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion

from diffusers import (
    EulerDiscreteScheduler,
    DDPMScheduler,
)

from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection
)
from toolkit.models.size_agnostic_feature_encoder import SAFEImageProcessor, SAFEVisionModel

from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification

from transformers import ViTFeatureExtractor, ViTForImageClassification

import torch.nn.functional as F
import torch.nn as nn

class ReferenceAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for the reference adapter, for PyTorch 2.0.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, adapter=None):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.ref_net = nn.Linear(hidden_size, hidden_size)
        self.blend = nn.Parameter(torch.zeros(hidden_size))
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self._memory = None

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        if self.adapter_ref().is_active:
            if self.adapter_ref().reference_mode == "write":
                # write mode: project this layer's hidden states and store them as memory
                memory_ref = self.ref_net(hidden_states)
                self._memory = memory_ref
            elif self.adapter_ref().reference_mode == "read":
                # read mode: blend the stored memory into the current hidden states
                if self._memory is None:
                    print("Warning: no memory to read from")
                else:
                    saved_hidden_states = self._memory
                    try:
                        new_hidden_states = saved_hidden_states
                        blend = self.blend
                        # expand the blend but keep dim 0 the same (batch)
                        while blend.ndim < new_hidden_states.ndim:
                            blend = blend.unsqueeze(0)
                        # expand batch
                        blend = torch.cat([blend] * new_hidden_states.shape[0], dim=0)
                        hidden_states = blend * new_hidden_states + (1 - blend) * hidden_states
                    except Exception as e:
                        raise Exception(f"Error blending: {e}")

        return hidden_states

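The write/read pair above is the heart of the adapter: a UNet pass over the reference caches a linear projection of each attention layer's hidden states, and the real pass then blends that cache back in through a learned per-channel gate that starts at zero. A toy version of the mechanism (TinyRefBlock is an editor's illustration, not the toolkit's API):

    import torch
    import torch.nn as nn

    class TinyRefBlock(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.ref_net = nn.Linear(hidden_size, hidden_size)
            self.blend = nn.Parameter(torch.zeros(hidden_size))
            self.mode = "write"
            self._memory = None

        def forward(self, h):
            if self.mode == "write":
                self._memory = self.ref_net(h)  # cache projected reference features
                return h
            blend = self.blend.expand_as(h)     # learned per-channel gate
            return blend * self._memory + (1 - blend) * h

    block = TinyRefBlock(8)
    block(torch.randn(1, 4, 8))        # "write": populates the memory
    block.mode = "read"
    out = block(torch.randn(1, 4, 8))  # "read": blends the memory in

Because blend is initialized to zero, the processor is an identity at the start of training and learns per channel how much reference signal to inject.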
class ReferenceAdapter(torch.nn.Module):

    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
        super().__init__()
        self.config = adapter_config
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.device = self.sd_ref().unet.device
        self.reference_mode = "read"
        self.current_scale = 1.0
        self.is_active = True
        self._reference_images = None
        self._reference_latents = None
        self.has_memory = False

        self.noise_scheduler: Union[DDPMScheduler, EulerDiscreteScheduler] = None

        # init adapter modules
        attn_procs = {}
        unet_sd = sd.unet.state_dict()
        for name in sd.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
                hidden_size = sd.unet.config['block_out_channels'][-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = sd.unet.config['block_out_channels'][block_id]
            else:
                # upstream didn't have this branch, but without it hidden_size would be undefined below
                raise ValueError(f"unknown attn processor name: {name}")
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor2_0()
            else:
                # layer_name = name.split(".processor")[0]
                # weights = {
                #     "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                #     "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
                # }

                attn_procs[name] = ReferenceAttnProcessor2_0(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.config.num_tokens,
                    adapter=self
                )
                # attn_procs[name].load_state_dict(weights)
        sd.unet.set_attn_processor(attn_procs)
        adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

        sd.adapter = self
        self.unet_ref: weakref.ref = weakref.ref(sd.unet)
        self.adapter_modules = adapter_modules
        # load the weights if we have some
        if self.config.name_or_path:
            loaded_state_dict = load_ip_adapter_model(
                self.config.name_or_path,
                device='cpu',
                dtype=sd.torch_dtype
            )
            self.load_state_dict(loaded_state_dict)

        self.set_scale(1.0)
        self.attach()
        self.to(self.device, self.sd_ref().torch_dtype)

        # if self.config.train_image_encoder:
        #     self.image_encoder.train()
        #     self.image_encoder.requires_grad_(True)

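The block-name parsing above recovers each attention layer's channel width from the UNet config, the same way diffusers-style IP-Adapter loaders do: mid blocks take the last entry of block_out_channels, up blocks index the reversed list, down blocks index it directly. With SD 1.5's config as an illustrative input:

    block_out_channels = [320, 640, 1280, 1280]  # SD 1.5 UNet config

    def hidden_size_for(name):
        if name.startswith("mid_block"):
            return block_out_channels[-1]
        if name.startswith("up_blocks"):
            return list(reversed(block_out_channels))[int(name[len("up_blocks.")])]
        if name.startswith("down_blocks"):
            return block_out_channels[int(name[len("down_blocks.")])]
        raise ValueError(name)

    print(hidden_size_for("up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"))  # 1280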
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # self.image_encoder.to(*args, **kwargs)
        # self.image_proj_model.to(*args, **kwargs)
        self.adapter_modules.to(*args, **kwargs)
        return self

    def load_reference_adapter(self, state_dict: Union[OrderedDict, dict]):
        reference_layers = torch.nn.ModuleList(self.unet_ref().attn_processors.values())
        reference_layers.load_state_dict(state_dict["reference_adapter"])

    # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
    #     self.load_ip_adapter(state_dict)

    def state_dict(self) -> OrderedDict:
        state_dict = OrderedDict()
        state_dict["reference_adapter"] = self.adapter_modules.state_dict()
        return state_dict

    def get_scale(self):
        return self.current_scale

    def set_reference_images(self, reference_images: Optional[torch.Tensor]):
        # guard the None case; callers pass None to clear the references
        self._reference_images = None if reference_images is None else reference_images.clone().detach()
        self._reference_latents = None
        self.clear_memory()

    def set_blank_reference_images(self, batch_size):
        self._reference_images = torch.zeros((batch_size, 3, 512, 512), device=self.device, dtype=self.sd_ref().torch_dtype)
        self._reference_latents = torch.zeros((batch_size, 4, 64, 64), device=self.device, dtype=self.sd_ref().torch_dtype)
        self.clear_memory()

    def set_scale(self, scale):
        self.current_scale = scale
        for attn_processor in self.sd_ref().unet.attn_processors.values():
            if isinstance(attn_processor, ReferenceAttnProcessor2_0):
                attn_processor.scale = scale

    def attach(self):
        unet = self.sd_ref().unet
        self._original_unet_forward = unet.forward
        unet.forward = lambda *args, **kwargs: self.unet_forward(*args, **kwargs)
        if self.sd_ref().network is not None:
            # set network to not merge in
            self.sd_ref().network.can_merge_in = False

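attach hijacks unet.forward with a closure so every denoising call can run the reference "write" pass before the real one, without subclassing the UNet. The pattern in miniature (ForwardHijack is a hypothetical name):

    class ForwardHijack:
        # route every call through our hook, then fall through to the original
        def attach(self, module):
            self._original_forward = module.forward
            module.forward = lambda *args, **kwargs: self.hooked_forward(*args, **kwargs)

        def hooked_forward(self, *args, **kwargs):
            # a reference "write" pass would run here first
            return self._original_forward(*args, **kwargs)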
    def unet_forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
        skip = False
        if self._reference_images is None and self._reference_latents is None:
            skip = True
        if not self.is_active:
            skip = True

        if self.has_memory:
            skip = True

        if not skip:
            if self.sd_ref().network is not None:
                self.sd_ref().network.is_active = True
                if self.sd_ref().network.is_merged_in:
                    raise ValueError("network is merged in, but we are not supposed to be merged in")
            # send it through our forward first
            self.forward(sample, timestep, encoder_hidden_states, *args, **kwargs)

            if self.sd_ref().network is not None:
                self.sd_ref().network.is_active = False

        # send it through the original unet forward
        return self._original_unet_forward(sample, timestep, encoder_hidden_states, *args, **kwargs)


    # use drop for prompt dropout, or negatives
    def forward(self, sample, timestep, encoder_hidden_states, *args, **kwargs):
        if not self.noise_scheduler:
            raise ValueError("noise scheduler not set")
        if not self.is_active or (self._reference_images is None and self._reference_latents is None):
            raise ValueError("reference adapter not active or no reference images set")
        # todo may need to handle cfg?
        self.reference_mode = "write"

        if self._reference_latents is None:
            self._reference_latents = self.sd_ref().encode_images(self._reference_images.to(
                self.device, self.sd_ref().torch_dtype
            )).detach()
        # create a sample from our reference images
        reference_latents = self._reference_latents.clone().detach().to(self.device, self.sd_ref().torch_dtype)
        # if our num of samples is half of the incoming batch, we are doing cfg. Zero out the first half (unconditional)
        if reference_latents.shape[0] * 2 == sample.shape[0]:
            # we are doing cfg
            # unconditional goes first
            reference_latents = torch.cat([torch.zeros_like(reference_latents), reference_latents], dim=0).detach()

        # resize it so reference_latents will fit inside sample in the center
        width_scale = sample.shape[2] / reference_latents.shape[2]
        height_scale = sample.shape[3] / reference_latents.shape[3]
        scale = min(width_scale, height_scale)
        # resize the reference latents

        mode = "bilinear" if scale > 1.0 else "bicubic"

        reference_latents = F.interpolate(
            reference_latents,
            size=(int(reference_latents.shape[2] * scale), int(reference_latents.shape[3] * scale)),
            mode=mode,
            align_corners=False
        )

        # add 0 padding if needed
        width_pad = (sample.shape[2] - reference_latents.shape[2]) / 2
        height_pad = (sample.shape[3] - reference_latents.shape[3]) / 2
        reference_latents = F.pad(
            reference_latents,
            (math.floor(width_pad), math.floor(width_pad), math.ceil(height_pad), math.ceil(height_pad)),
            mode="constant",
            value=0
        )

        # resize again just to make sure it is the exact same size
        reference_latents = F.interpolate(
            reference_latents,
            size=(sample.shape[2], sample.shape[3]),
            mode="bicubic",
            align_corners=False
        )

        # todo maybe add same noise to the sample? For now we will send it through with no noise
        # sample_imgs = self.noise_scheduler.add_noise(sample_imgs, timestep)
        self._original_unet_forward(reference_latents, timestep, encoder_hidden_states, *args, **kwargs)
        self.reference_mode = "read"
        self.has_memory = True
        return None

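forward letterboxes the reference latents into the sample's spatial shape: scale to fit, zero-pad to center, then a final interpolate to force an exact size. The same geometry on dummy tensors (this sketch pairs floor/ceil per side so the fit is exact; the diff instead repeats floor for width and ceil for height and relies on the final resize):

    import math
    import torch
    import torch.nn.functional as F

    sample = torch.zeros(1, 4, 96, 64)   # target latent shape
    ref = torch.randn(1, 4, 64, 64)      # reference latents

    scale = min(sample.shape[2] / ref.shape[2], sample.shape[3] / ref.shape[3])
    ref = F.interpolate(ref, size=(int(ref.shape[2] * scale), int(ref.shape[3] * scale)),
                        mode="bilinear", align_corners=False)

    pad_h = (sample.shape[2] - ref.shape[2]) / 2
    pad_w = (sample.shape[3] - ref.shape[3]) / 2
    # F.pad takes (left, right, top, bottom) for the last two dims
    ref = F.pad(ref, (math.floor(pad_w), math.ceil(pad_w), math.floor(pad_h), math.ceil(pad_h)))
    print(ref.shape)  # torch.Size([1, 4, 96, 64])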
    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        for attn_processor in self.adapter_modules:
            yield from attn_processor.parameters(recurse)
        # yield from self.image_proj_model.parameters(recurse)
        # if self.config.train_image_encoder:
        #     yield from self.image_encoder.parameters(recurse)
        # if self.config.train_image_encoder:
        #     yield from self.image_encoder.parameters(recurse)
        #     self.image_encoder.train()
        # else:
        #     for attn_processor in self.adapter_modules:
        #         yield from attn_processor.parameters(recurse)
        #     yield from self.image_proj_model.parameters(recurse)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        strict = False
        # self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
        self.adapter_modules.load_state_dict(state_dict["reference_adapter"], strict=strict)

    def enable_gradient_checkpointing(self):
        self.image_encoder.gradient_checkpointing = True

    def clear_memory(self):
        for attn_processor in self.adapter_modules:
            if isinstance(attn_processor, ReferenceAttnProcessor2_0):
                attn_processor._memory = None
        self.has_memory = False
@@ -27,6 +27,7 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
 from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds
+from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
 from toolkit.sd_device_states_presets import empty_preset
@@ -76,6 +77,7 @@ class BlankNetwork:
     self.multiplier = 1.0
     self.is_active = True
     self.is_merged_in = False
+    self.can_merge_in = False

 def __enter__(self):
     self.is_active = True
@@ -134,7 +136,7 @@ class StableDiffusion:

 # to hold network if there is one
 self.network = None
-self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None
+self.adapter: Union['T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
 self.is_xl = model_config.is_xl
 self.is_v2 = model_config.is_v2
 self.is_ssd = model_config.is_ssd
@@ -396,6 +398,9 @@ class StableDiffusion:
     else:
         Pipe = StableDiffusionAdapterPipeline
     extra_args['adapter'] = self.adapter
+elif isinstance(self.adapter, ReferenceAdapter):
+    # pass the noise scheduler to the adapter
+    self.adapter.noise_scheduler = noise_scheduler
 else:
     if self.is_xl:
         extra_args['add_watermarker'] = False
@@ -478,6 +483,12 @@ class StableDiffusion:
     transforms.ToTensor(),
 ])
 validation_image = transform(validation_image)
+if isinstance(self.adapter, ReferenceAdapter):
+    # need -1 to 1
+    validation_image = transforms.ToTensor()(validation_image)
+    validation_image = validation_image * 2.0 - 1.0
+    validation_image = validation_image.unsqueeze(0)
+    self.adapter.set_reference_images(validation_image)

 if self.network is not None:
     self.network.multiplier = gen_config.network_multiplier
@@ -594,6 +605,9 @@ class StableDiffusion:

 gen_config.save_image(img, i)

+if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
+    self.adapter.clear_memory()
+
 # clear pipeline and cache to reduce vram usage
 del pipeline
 if refiner_pipeline is not None:
@@ -1455,6 +1469,10 @@ class StableDiffusion:
 elif isinstance(self.adapter, ClipVisionAdapter):
     requires_grad = self.adapter.embedder.training
     adapter_device = self.adapter.device
+elif isinstance(self.adapter, ReferenceAdapter):
+    # todo update this!!
+    requires_grad = True
+    adapter_device = self.adapter.device
 else:
     raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
 self.device_state['adapter'] = {