# Mirror of https://github.com/ostris/ai-toolkit.git
# Synced 2026-01-26 16:39:47 +00:00 (407 lines, 18 KiB, Python)
from typing import TYPE_CHECKING, Mapping, Any
|
|
|
|
import torch
|
|
import weakref
|
|
|
|
from toolkit.config_modules import AdapterConfig
|
|
from toolkit.models.clip_fusion import ZipperBlock
|
|
from toolkit.models.zipper_resampler import ZipperModule
|
|
from toolkit.prompt_utils import PromptEmbeds
|
|
from toolkit.train_tools import get_torch_dtype
|
|
|
|
if TYPE_CHECKING:
|
|
from toolkit.stable_diffusion_model import StableDiffusion
|
|
|
|
from transformers import (
|
|
CLIPImageProcessor,
|
|
CLIPVisionModelWithProjection,
|
|
CLIPVisionModel
|
|
)
|
|
|
|
from toolkit.resampler import Resampler
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
class Embedder(nn.Module):
    """Small MLP mapping image-encoder features to a fixed set of token vectors.

    Pipeline: LayerNorm -> Linear -> GELU -> Linear -> LayerNorm -> Linear ->
    GELU -> Linear, followed by a reshape to
    ``(-1, num_output_tokens, output_dim)``.

    The weights of ``fc2``, ``fc3`` and ``fc4`` start at zero; their biases keep
    PyTorch's default init. Because ``fc4.weight`` is zero, a freshly built
    module outputs ``fc4.bias`` (reshaped) whatever the input is.

    ``num_input_tokens`` and ``input_dim`` are stored as attributes only;
    ``forward`` does not use ``num_input_tokens``.
    """

    def __init__(
            self,
            num_input_tokens: int = 1,
            input_dim: int = 1024,
            num_output_tokens: int = 8,
            output_dim: int = 768,
            mid_dim: int = 1024
    ):
        super().__init__()
        self.num_output_tokens = num_output_tokens
        self.num_input_tokens = num_input_tokens
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Creation order is significant. It fixes the state_dict key order (and
        # thus checkpoint compatibility) and the order in which each Linear's
        # random init draws from the global RNG.
        self.layer_norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(mid_dim, mid_dim)
        self.layer_norm2 = nn.LayerNorm(mid_dim)
        self.fc3 = nn.Linear(mid_dim, mid_dim)
        self.gelu2 = nn.GELU()
        self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens)

        # Zero the weight matrices of the last three projections (biases untouched).
        for projection in (self.fc2, self.fc3, self.fc4):
            projection.weight.data.zero_()

    def forward(self, x):
        """Project features into ``num_output_tokens`` vectors of width ``output_dim``.

        Args:
            x: Tensor whose last dimension is ``input_dim``. A 2-D input gets a
                token axis inserted at position 1.

        Returns:
            Tensor viewed as ``(-1, num_output_tokens, output_dim)``. With a
            multi-token 3-D input, the token axis is folded into the leading
            dimension by this view.
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        stages = (
            self.layer_norm, self.fc1, self.gelu, self.fc2,
            self.layer_norm2, self.fc3, self.gelu2, self.fc4,
        )
        for stage in stages:
            x = stage(x)
        return x.view(-1, self.num_output_tokens, self.output_dim)
|
|
|
|
|
|
class ClipVisionAdapter(torch.nn.Module):
    """Adapter that writes CLIP-image-derived vectors into placeholder-token embeddings.

    On construction it:

    * adds ``config.num_tokens`` placeholder tokens (the trigger word, then
      ``<trigger>_1`` ... ``<trigger>_{num_tokens-1}``) to every tokenizer;
    * resizes the matching text encoders and seeds the new rows with the
      embeddings of ``config.trigger_class_name``;
    * loads a CLIP vision encoder and builds an :class:`Embedder`.

    :meth:`forward` maps image features through the embedder and copies the
    result into those placeholder rows.
    """

    # What get_clip_image_embeds_from_tensors feeds the image encoder when
    # ``drop=True``. True means uniform noise scaled by a random per-sample
    # factor; False means all zeros. Previously this attribute was read but
    # never assigned, so every ``drop=True`` call raised AttributeError.
    # NOTE(review): True is an assumed default; confirm the intended value.
    clip_noise_zero: bool = True

    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
        """Register placeholder tokens and build the image encoder and embedder.

        Args:
            sd: Wrapped model. Only a weak reference is stored. Its ``tokenizer``
                and ``text_encoder`` (single objects or lists), ``unet.device``,
                ``unet.config['cross_attention_dim']`` and ``dtype`` are read.
            adapter_config: Supplies ``trigger``, ``trigger_class_name``,
                ``num_tokens``, ``image_encoder_path``, ``train_image_encoder``
                and ``clip_layer``.

        Raises:
            ValueError: If a tokenizer did not add exactly ``num_tokens`` new
                tokens (e.g. the trigger is already in its vocabulary).
        """
        super().__init__()
        self.config = adapter_config
        self.trigger = adapter_config.trigger
        self.trigger_class_name = adapter_config.trigger_class_name
        self.sd_ref: weakref.ref = weakref.ref(sd)

        # Multi-encoder models expose lists; normalise a single tokenizer /
        # text encoder to a one-element list.
        self.text_encoder_list = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder]
        self.tokenizer_list = sd.tokenizer if isinstance(sd.tokenizer, list) else [sd.tokenizer]

        # One placeholder token per output vector: "<trigger>", "<trigger>_1", ...
        placeholder_tokens = [self.trigger] + [
            f"{self.trigger}_{i}" for i in range(1, self.config.num_tokens)
        ]

        self.placeholder_token_ids = []
        self.embedding_tokens = []

        print(f"Adding {placeholder_tokens} tokens to tokenizer")
        print(f"Adding {self.config.num_tokens} tokens to tokenizer")

        for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
            self._add_placeholder_tokens(text_encoder, tokenizer, placeholder_tokens)

        # Snapshot of each input-embedding table. restore_embeddings() copies
        # every row outside the placeholder range back from it.
        self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]

        try:
            self.clip_image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path)
        except EnvironmentError:
            # No processor config at that path: fall back to default settings.
            self.clip_image_processor = CLIPImageProcessor()
        self.device = self.sd_ref().unet.device
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.config.image_encoder_path,
            ignore_mismatched_sizes=True
        ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        if self.config.train_image_encoder:
            self.image_encoder.train()
        else:
            self.image_encoder.eval()

        # Sequence length of the encoder output (for CLIP: patch tokens + CLS).
        image_encoder_state_dict = self.image_encoder.state_dict()
        in_tokens = 257
        if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
            in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

        encoder_config = self.image_encoder.config
        if hasattr(encoder_config, 'hidden_sizes'):
            embedding_dim = encoder_config.hidden_sizes[-1]
        elif hasattr(encoder_config, 'target_hidden_size'):
            embedding_dim = encoder_config.target_hidden_size
        else:
            # CLIP vision configs expose their width as `hidden_size`. Reading
            # `target_hidden_size` unconditionally raised AttributeError.
            embedding_dim = encoder_config.hidden_size

        if self.config.clip_layer == 'image_embeds':
            # The projected image embedding is a single vector.
            in_tokens = 1
            embedding_dim = self.image_encoder.config.projection_dim

        self.embedder = Embedder(
            num_output_tokens=self.config.num_tokens,
            num_input_tokens=in_tokens,
            input_dim=embedding_dim,
            output_dim=self.sd_ref().unet.config['cross_attention_dim'],
            mid_dim=embedding_dim * self.config.num_tokens,
        ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))

        self.embedder.train()

    def _add_placeholder_tokens(self, text_encoder, tokenizer, placeholder_tokens):
        """Add the placeholder tokens to one tokenizer/text-encoder pair and seed their embeddings."""
        num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
        if num_added_tokens != self.config.num_tokens:
            raise ValueError(
                f"The tokenizer already contains the token {self.trigger}. Please pass a different"
                f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
            )

        # Ids of the class name, truncated or padded with "*" to num_tokens entries.
        init_token_ids = tokenizer.encode(self.config.trigger_class_name, add_special_tokens=False)
        if len(init_token_ids) > self.config.num_tokens:
            init_token_ids = init_token_ids[:self.config.num_tokens]
        elif len(init_token_ids) < self.config.num_tokens:
            pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
            init_token_ids += pad_token_id * (self.config.num_tokens - len(init_token_ids))

        placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
        self.placeholder_token_ids.append(placeholder_token_ids)

        # Grow the embedding table to cover the newly added tokens.
        text_encoder.resize_token_embeddings(len(tokenizer))

        # Initialise each new row from the corresponding class-name token.
        token_embeds = text_encoder.get_input_embeddings().weight.data
        with torch.no_grad():
            for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
                token_embeds[token_id] = token_embeds[initializer_token_id].clone()

        # Space-joined token strings; substituted for "[name]" in training prompts.
        self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))

    def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
        """Return ``{'embedder': ...}``, plus ``'image_encoder'`` when the encoder is trained.

        NOTE(review): a caller-supplied ``destination`` is passed to both
        sub-modules, so both entries would then share that one mapping.
        """
        state_dict = {
            'embedder': self.embedder.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
        }
        if self.config.train_image_encoder:
            state_dict['image_encoder'] = self.image_encoder.state_dict(
                *args, destination=destination, prefix=prefix,
                keep_vars=keep_vars)

        return state_dict

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        """Load the nested mapping produced by :meth:`state_dict`.

        The ``'image_encoder'`` entry is loaded only when present and
        ``train_image_encoder`` is set.
        """
        self.embedder.load_state_dict(state_dict["embedder"], strict=strict)
        if self.config.train_image_encoder and 'image_encoder' in state_dict:
            self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)

    def parameters(self, *args, **kwargs):
        """Yield only the embedder's parameters.

        NOTE(review): image-encoder parameters are not yielded even when
        ``train_image_encoder`` is set (although state_dict saves them). Confirm
        this is intended.
        """
        yield from self.embedder.parameters(*args, **kwargs)

    def named_parameters(self, *args, **kwargs):
        """Yield only the embedder's named parameters (see :meth:`parameters`)."""
        yield from self.embedder.named_parameters(*args, **kwargs)

    def _make_unconditional_images(self, tensors_0_1: torch.Tensor) -> torch.Tensor:
        """Return the stand-in image batch used when ``drop=True``.

        The result is scaled uniform noise if ``clip_noise_zero`` is set,
        otherwise zeros of the same shape.
        """
        if self.clip_noise_zero:
            noise = torch.rand_like(tensors_0_1).detach()
            noise_scale = torch.rand([noise.shape[0], 1, 1, 1], device=self.device,
                                     dtype=get_torch_dtype(self.sd_ref().dtype))
            return noise * noise_scale
        return torch.zeros_like(tensors_0_1).detach()

    def get_clip_image_embeds_from_tensors(
            self, tensors_0_1: torch.Tensor, drop=False,
            is_training=False,
            has_been_preprocessed=False
    ) -> torch.Tensor:
        """Run images through the CLIP vision encoder and return the configured layer.

        Args:
            tensors_0_1: Raw images in [0, 1] (a 3-D tensor gets a batch axis).
                If ``has_been_preprocessed`` is set, they must already be
                normalised pixel values.
            drop: Replace the images with noise/zeros (see ``clip_noise_zero``)
                to get an unconditional embedding.
            is_training: Run the encoder in train mode with gradients enabled.
            has_been_preprocessed: Skip the image processor.

        Returns:
            ``hidden_states[-2]``, ``hidden_states[-1]``, or ``image_embeds``,
            depending on ``config.clip_layer``.

        Raises:
            ValueError: If unpreprocessed values fall outside [-0.3, 1.3].
        """
        with torch.no_grad():
            if not has_been_preprocessed:
                if tensors_0_1.ndim == 3:
                    tensors_0_1 = tensors_0_1.unsqueeze(0)
                # NOTE(review): cast is hard-coded to float16, not the model dtype.
                tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)

                if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                    raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                        tensors_0_1.min(), tensors_0_1.max()
                    ))
                if drop:
                    tensors_0_1 = self._make_unconditional_images(tensors_0_1)
                clip_image = self.clip_image_processor(
                    images=tensors_0_1,
                    return_tensors="pt",
                    do_resize=True,
                    do_rescale=False,
                ).pixel_values
            elif drop:
                # The substitute images are in [0, 1] space, so normalise them
                # the way the image processor would.
                tensors_0_1 = self._make_unconditional_images(tensors_0_1)
                dtype = get_torch_dtype(self.sd_ref().dtype)
                mean = torch.tensor(self.clip_image_processor.image_mean).to(self.device, dtype=dtype).detach()
                std = torch.tensor(self.clip_image_processor.image_std).to(self.device, dtype=dtype).detach()
                # Quantise to 8-bit levels like a real image.
                tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
                clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
            else:
                clip_image = tensors_0_1
            clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

        with torch.set_grad_enabled(is_training):
            if is_training:
                self.image_encoder.train()
            else:
                self.image_encoder.eval()
            clip_output = self.image_encoder(clip_image, output_hidden_states=True)

        if self.config.clip_layer == 'penultimate_hidden_states':
            # IP-Adapter+ skips the last layer:
            # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
            return clip_output.hidden_states[-2]
        if self.config.clip_layer == 'last_hidden_state':
            return clip_output.hidden_states[-1]
        return clip_output.image_embeds

    def set_vec(self, new_vector, text_encoder_idx=0):
        """Write ``new_vector[i]`` into the i-th placeholder row of one text encoder's embeddings."""
        embedding_layer = self.text_encoder_list[text_encoder_idx].get_input_embeddings()
        # Positional pairing (like the init in __init__); no contiguity assumption.
        for offset, token_id in enumerate(self.placeholder_token_ids[text_encoder_idx]):
            embedding_layer.weight[token_id] = new_vector[offset]

    def forward(self, clip_image_embeds: torch.Tensor) -> None:
        """Embed image features and write them into the placeholder-token rows.

        Nothing is returned; the text encoders' input embeddings are modified
        in place. With two text encoders, the embedder output is split along
        its last axis at the first encoder's embedding width.

        Raises:
            ValueError: On batch size != 1, a width mismatch with the text
                encoders, or more than two text encoders.
        """
        clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        if clip_image_embeds.ndim == 2:
            # expand the token dimension
            clip_image_embeds = clip_image_embeds.unsqueeze(1)
        image_prompt_embeds = self.embedder(clip_image_embeds)
        # TODO: support batch sizes > 1
        if image_prompt_embeds.shape[0] != 1:
            raise ValueError("Batch size must be 1 for embedder for now")

        num_encoders = len(self.text_encoder_list)
        if num_encoders == 1:
            self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
        elif num_encoders == 2:
            # Width of each encoder's input-embedding table. CLIP text configs
            # have no `target_hidden_size`, which the old code read here.
            te1_dim, te2_dim = (
                te.get_input_embeddings().weight.shape[-1] for te in self.text_encoder_list
            )
            if te1_dim + te2_dim != image_prompt_embeds.shape[2]:
                raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes")
            self.set_vec(image_prompt_embeds[0, :, :te1_dim], text_encoder_idx=0)
            self.set_vec(image_prompt_embeds[0, :, te1_dim:], text_encoder_idx=1)
        else:
            raise ValueError("Unsupported number of text encoders")

    def restore_embeddings(self):
        """Restore every embedding row outside the placeholder id range from the snapshot.

        Also detaches each embedding weight in place.
        """
        for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(
                self.text_encoder_list,
                self.tokenizer_list,
                self.orig_embeds_params,
                self.placeholder_token_ids
        ):
            index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
            index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
            with torch.no_grad():
                text_encoder.get_input_embeddings().weight[
                    index_no_updates
                ] = orig_embeds[index_no_updates]
            text_encoder.get_input_embeddings().weight.detach_()

    def enable_gradient_checkpointing(self):
        """Set ``gradient_checkpointing = True`` on the image encoder object.

        NOTE(review): Hugging Face models normally toggle this through
        ``gradient_checkpointing_enable()`` on an inner module. Setting the flag
        on the top-level model may have no effect; confirm.
        """
        self.image_encoder.gradient_checkpointing = True

    @staticmethod
    def _unique_longest_first(candidates):
        """Deduplicate replacement targets (dropping empty strings), longest first.

        Longest-first ensures a target containing another target (the expanded
        embedding tokens vs. the bare trigger) is replaced before its substring
        can split it. It is also deterministic, unlike iterating a ``set``.
        """
        unique = [candidate for candidate in dict.fromkeys(candidates) if candidate]
        return sorted(unique, key=len, reverse=True)

    def inject_trigger_into_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
        """Replace "[name]", "[trigger]" and any extra markers with the trigger.

        Args:
            prompt: Prompt text.
            expand_token: Use the space-joined placeholder tokens instead of the
                bare trigger.
            to_replace_list: Extra markers to replace. It is not modified.
            add_if_not_present: Prepend the replacement if it does not occur.

        Returns:
            The rewritten prompt. A warning is printed when the replacement
            occurs more than once.
        """
        # Assumed identical across tokenizers.
        embedding_tokens = self.embedding_tokens[0]
        replace_with = embedding_tokens if expand_token else self.trigger

        # Copy instead of `+=` so the caller's list is never mutated.
        extra = [] if to_replace_list is None else list(to_replace_list)

        output_prompt = prompt
        for to_replace in self._unique_longest_first(extra + ["[name]", "[trigger]"]):
            output_prompt = output_prompt.replace(to_replace, replace_with)

        num_instances = output_prompt.count(replace_with)

        if num_instances == 0 and add_if_not_present:
            output_prompt = replace_with + " " + output_prompt

        if num_instances > 1:
            print(
                f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

        return output_prompt

    def inject_trigger_class_name_into_prompt(self, prompt):
        """Reverse the injection: replace markers, placeholder tokens and trigger with the class name.

        Useful for building class-only (normalisation) prompts. A warning is
        printed when the class name occurs more than once afterwards.
        """
        # Assumed identical across tokenizers.
        embedding_tokens = self.embedding_tokens[0]
        replace_with = self.config.trigger_class_name

        output_prompt = prompt
        for to_replace in self._unique_longest_first(["[name]", "[trigger]", embedding_tokens, self.trigger]):
            output_prompt = output_prompt.replace(to_replace, replace_with)

        num_instances = output_prompt.count(replace_with)

        if num_instances > 1:
            print(
                f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")

        return output_prompt
|