mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added a clip vision adapter trainer. Only works for sd15 for now
This commit is contained in:
@@ -5,6 +5,7 @@ from diffusers import T2IAdapter
|
||||
import torch.functional as F
|
||||
from toolkit import train_tools
|
||||
from toolkit.basic import value_map, adain, get_mean_std
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.config_modules import GuidanceConfig
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
|
||||
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
|
||||
@@ -504,6 +505,7 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
noise: torch.Tensor,
|
||||
**kwargs
|
||||
):
|
||||
# todo for embeddings, we need to run without trigger words
|
||||
was_unet_training = self.sd.unet.training
|
||||
was_network_active = False
|
||||
if self.network is not None:
|
||||
@@ -519,13 +521,28 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
# do a prediction here so we can match its output with network multiplier set to 0.0
|
||||
with torch.no_grad():
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
embeds_to_use = conditional_embeds.clone().detach()
|
||||
# handle clip vision adapter by removing triggers from prompt and replacing with the class name
|
||||
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
|
||||
prompt_list = batch.get_caption_list()
|
||||
for idx, prompt in enumerate(prompt_list):
|
||||
prompt = self.adapter.inject_trigger_class_name_into_prompt(prompt)
|
||||
prompt_list[idx] = prompt
|
||||
|
||||
embeds_to_use = self.sd.encode_prompt(
|
||||
prompt,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype).detach()
|
||||
|
||||
# dont use network on this
|
||||
# self.network.multiplier = 0.0
|
||||
self.sd.unet.eval()
|
||||
|
||||
prior_pred = self.sd.predict_noise(
|
||||
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
|
||||
conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
|
||||
timestep=timesteps,
|
||||
guidance_scale=1.0,
|
||||
**pred_kwargs # adapter residuals in here
|
||||
@@ -666,6 +683,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if self.embedding:
|
||||
grad_on_text_encoder = True
|
||||
|
||||
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
|
||||
grad_on_text_encoder = True
|
||||
|
||||
# have a blank network so we can wrap it in a context and set multipliers without checking every time
|
||||
if self.network is not None:
|
||||
network = self.network
|
||||
@@ -745,6 +765,26 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
|
||||
|
||||
with network:
|
||||
# encode clip adapter here so embeds are active for tokenizer
|
||||
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
|
||||
with self.timer('encode_clip_vision_embeds'):
|
||||
if has_clip_image:
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
clip_images.detach().to(self.device_torch, dtype=dtype),
|
||||
is_training=True
|
||||
)
|
||||
else:
|
||||
# just do a blank one
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
torch.zeros(
|
||||
(noisy_latents.shape[0], 3, 512, 512),
|
||||
device=self.device_torch, dtype=dtype
|
||||
),
|
||||
is_training=True
|
||||
)
|
||||
# it will be injected into the tokenizer when called
|
||||
self.adapter(conditional_clip_embeds)
|
||||
|
||||
with self.timer('encode_prompt'):
|
||||
if grad_on_text_encoder:
|
||||
with torch.set_grad_enabled(True):
|
||||
@@ -912,6 +952,10 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
with self.timer('restore_embeddings'):
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
self.embedding.restore_embeddings()
|
||||
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
|
||||
with self.timer('restore_adapter'):
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
self.adapter.restore_embeddings()
|
||||
|
||||
loss_dict = OrderedDict(
|
||||
{'loss': loss.item()}
|
||||
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
from typing import Union, List
|
||||
from typing import Union, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
@@ -17,6 +17,7 @@ import torch
|
||||
import torch.backends.cuda
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.embedding import Embedding
|
||||
@@ -138,7 +139,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
# to hold network if there is one
|
||||
self.network: Union[Network, None] = None
|
||||
self.adapter: Union[T2IAdapter, IPAdapter, None] = None
|
||||
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, None] = None
|
||||
self.embedding: Union[Embedding, None] = None
|
||||
|
||||
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
|
||||
@@ -202,7 +203,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
|
||||
if self.embedding is not None:
|
||||
prompt = self.embedding.inject_embedding_to_prompt(
|
||||
prompt, add_if_not_present=False
|
||||
prompt, expand_token=True, add_if_not_present=False
|
||||
)
|
||||
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
|
||||
prompt = self.adapter.inject_trigger_into_prompt(
|
||||
prompt, expand_token=True, add_if_not_present=False
|
||||
)
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
@@ -400,6 +405,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# add _lora to name
|
||||
if self.adapter_config.type == 't2i':
|
||||
adapter_name += '_t2i'
|
||||
elif self.adapter_config.type == 'clip':
|
||||
adapter_name += '_clip'
|
||||
else:
|
||||
adapter_name += '_ip'
|
||||
|
||||
@@ -647,6 +654,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
|
||||
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
|
||||
prompt = self.adapter.inject_trigger_into_prompt(
|
||||
prompt,
|
||||
expand_token=True,
|
||||
add_if_not_present=not is_reg,
|
||||
)
|
||||
|
||||
# make sure trigger is in the prompts if not a regularization run
|
||||
if self.trigger_word is not None:
|
||||
prompt = self.sd.inject_trigger_into_prompt(
|
||||
@@ -840,7 +854,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def setup_adapter(self):
|
||||
# t2i adapter
|
||||
is_t2i = self.adapter_config.type == 't2i'
|
||||
suffix = 't2i' if is_t2i else 'ip'
|
||||
if self.adapter_config.type == 't2i':
|
||||
suffix = 't2i'
|
||||
elif self.adapter_config.type == 'clip':
|
||||
suffix = 'clip'
|
||||
else:
|
||||
suffix = 'ip'
|
||||
adapter_name = self.name
|
||||
if self.network_config is not None:
|
||||
adapter_name = f"{adapter_name}_{suffix}"
|
||||
@@ -865,6 +884,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
downscale_factor=self.adapter_config.downscale_factor,
|
||||
adapter_type=self.adapter_config.adapter_type,
|
||||
)
|
||||
elif self.adapter_config.type == 'clip':
|
||||
self.adapter = ClipVisionAdapter(
|
||||
sd=self.sd,
|
||||
adapter_config=self.adapter_config,
|
||||
)
|
||||
else:
|
||||
self.adapter = IPAdapter(
|
||||
sd=self.sd,
|
||||
|
||||
331
toolkit/clip_vision_adapter.py
Normal file
331
toolkit/clip_vision_adapter.py
Normal file
@@ -0,0 +1,331 @@
|
||||
from typing import TYPE_CHECKING, Mapping, Any
|
||||
|
||||
import torch
|
||||
import weakref
|
||||
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
CLIPVisionModel
|
||||
)
|
||||
|
||||
from toolkit.resampler import Resampler
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Embedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_input_tokens: int = 50,
|
||||
input_dim: int = 1024,
|
||||
num_output_tokens: int = 8,
|
||||
output_dim: int = 768,
|
||||
mid_dim: int = 128,
|
||||
):
|
||||
super(Embedder, self).__init__()
|
||||
self.num_output_tokens = num_output_tokens
|
||||
self.num_input_tokens = num_input_tokens
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
|
||||
# Convolutional layer to reduce channel dimension
|
||||
self.conv = nn.Conv1d(in_channels=input_dim, out_channels=mid_dim, kernel_size=1)
|
||||
|
||||
# GELU Activation
|
||||
self.gelu = nn.GELU()
|
||||
|
||||
# Layer Normalization
|
||||
self.layer_norm = nn.LayerNorm(mid_dim)
|
||||
|
||||
# Adaptive pooling to change sequence length
|
||||
self.adaptive_pool = nn.AdaptiveAvgPool1d(num_output_tokens)
|
||||
|
||||
# Linear layer for final transformation
|
||||
self.final_linear = nn.Linear(mid_dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 1) # Adjust for Conv1d
|
||||
x = self.conv(x)
|
||||
x = self.gelu(x)
|
||||
x = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1) # Apply LayerNorm
|
||||
x = self.adaptive_pool(x)
|
||||
x = x.permute(0, 2, 1) # Adjust back
|
||||
x = self.final_linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class ClipVisionAdapter(torch.nn.Module):
|
||||
def __init__(self, sd: 'StableDiffusion', adapter_config: AdapterConfig):
|
||||
super().__init__()
|
||||
self.config = adapter_config
|
||||
self.trigger = adapter_config.trigger
|
||||
self.trigger_class_name = adapter_config.trigger_class_name
|
||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||
# embedding stuff
|
||||
self.text_encoder_list = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder]
|
||||
self.tokenizer_list = sd.tokenizer if isinstance(sd.tokenizer, list) else [sd.tokenizer]
|
||||
placeholder_tokens = [self.trigger]
|
||||
|
||||
# add dummy tokens for multi-vector
|
||||
additional_tokens = []
|
||||
for i in range(1, self.config.num_tokens):
|
||||
additional_tokens.append(f"{self.trigger}_{i}")
|
||||
placeholder_tokens += additional_tokens
|
||||
|
||||
# handle dual tokenizer
|
||||
self.tokenizer_list = self.sd_ref().tokenizer if isinstance(self.sd_ref().tokenizer, list) else [
|
||||
self.sd_ref().tokenizer]
|
||||
self.text_encoder_list = self.sd_ref().text_encoder if isinstance(self.sd_ref().text_encoder, list) else [
|
||||
self.sd_ref().text_encoder]
|
||||
|
||||
self.placeholder_token_ids = []
|
||||
self.embedding_tokens = []
|
||||
|
||||
print(f"Adding {placeholder_tokens} tokens to tokenizer")
|
||||
print(f"Adding {self.config.num_tokens} tokens to tokenizer")
|
||||
|
||||
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
|
||||
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
|
||||
if num_added_tokens != self.config.num_tokens:
|
||||
raise ValueError(
|
||||
f"The tokenizer already contains the token {self.trigger}. Please pass a different"
|
||||
f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
|
||||
)
|
||||
|
||||
# Convert the initializer_token, placeholder_token to ids
|
||||
init_token_ids = tokenizer.encode(self.config.trigger_class_name, add_special_tokens=False)
|
||||
# if length of token ids is more than number of orm embedding tokens fill with *
|
||||
if len(init_token_ids) > self.config.num_tokens:
|
||||
init_token_ids = init_token_ids[:self.config.num_tokens]
|
||||
elif len(init_token_ids) < self.config.num_tokens:
|
||||
pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
|
||||
init_token_ids += pad_token_id * (self.config.num_tokens - len(init_token_ids))
|
||||
|
||||
placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
|
||||
self.placeholder_token_ids.append(placeholder_token_ids)
|
||||
|
||||
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
with torch.no_grad():
|
||||
for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
|
||||
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
|
||||
|
||||
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
|
||||
self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
|
||||
|
||||
# backup text encoder embeddings
|
||||
self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
|
||||
|
||||
try:
|
||||
self.clip_image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path)
|
||||
except EnvironmentError:
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
self.device = self.sd_ref().unet.device
|
||||
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
self.config.image_encoder_path,
|
||||
ignore_mismatched_sizes=True
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
if self.config.train_image_encoder:
|
||||
self.image_encoder.train()
|
||||
else:
|
||||
self.image_encoder.eval()
|
||||
# self.embedder = Embedder(
|
||||
# num_output_tokens=self.config.num_tokens,
|
||||
# num_input_tokens=self.image_encoder.config.top_k, # max_position_embeddings ?
|
||||
# input_dim=self.image_encoder.config.hidden_size,
|
||||
# output_dim=sd.unet.config['cross_attention_dim'],
|
||||
# ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
heads = 12 if not sd.is_xl else 20
|
||||
# dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
|
||||
dim = sd.unet.config['cross_attention_dim']
|
||||
self.embedder = Resampler(
|
||||
dim=dim,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=heads,
|
||||
num_queries=self.config.num_tokens, # usually 16
|
||||
embedding_dim=self.image_encoder.config.hidden_size,
|
||||
output_dim=sd.unet.config['cross_attention_dim'],
|
||||
ff_mult=4
|
||||
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
|
||||
self.embedder.train()
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
|
||||
state_dict = {
|
||||
'embedder': self.embedder.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
}
|
||||
if self.config.train_image_encoder:
|
||||
state_dict['image_encoder'] = self.image_encoder.state_dict(
|
||||
*args, destination=destination, prefix=prefix,
|
||||
keep_vars=keep_vars)
|
||||
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
self.embedder.load_state_dict(state_dict["embedder"], strict=strict)
|
||||
if self.config.train_image_encoder and 'image_encoder' in state_dict:
|
||||
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
|
||||
|
||||
def parameters(self, *args, **kwargs):
|
||||
yield from self.embedder.parameters(*args, **kwargs)
|
||||
|
||||
def named_parameters(self, *args, **kwargs):
|
||||
yield from self.embedder.named_parameters(*args, **kwargs)
|
||||
|
||||
def get_clip_image_embeds_from_tensors(
|
||||
self, tensors_0_1: torch.Tensor, drop=False,
|
||||
is_training=False
|
||||
) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
# tensors should be 0-1
|
||||
# todo: add support for sdxl
|
||||
if tensors_0_1.ndim == 3:
|
||||
tensors_0_1 = tensors_0_1.unsqueeze(0)
|
||||
# training tensors are 0 - 1
|
||||
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
|
||||
# if images are out of this range throw error
|
||||
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
|
||||
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
|
||||
tensors_0_1.min(), tensors_0_1.max()
|
||||
))
|
||||
|
||||
clip_image = self.clip_image_processor(
|
||||
images=tensors_0_1,
|
||||
return_tensors="pt",
|
||||
do_resize=True,
|
||||
do_rescale=False,
|
||||
).pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
|
||||
if drop:
|
||||
clip_image = clip_image * 0
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if is_training:
|
||||
self.image_encoder.train()
|
||||
else:
|
||||
self.image_encoder.eval()
|
||||
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
|
||||
clip_image_embeds = clip_output.hidden_states[-2]
|
||||
return clip_image_embeds
|
||||
|
||||
import torch
|
||||
|
||||
def set_vec(self, new_vector, text_encoder_idx=0):
|
||||
# Get the embedding layer
|
||||
embedding_layer = self.text_encoder_list[text_encoder_idx].get_input_embeddings()
|
||||
|
||||
# Indices to replace in the embeddings
|
||||
indices_to_replace = self.placeholder_token_ids[text_encoder_idx]
|
||||
|
||||
# Replace the specified embeddings with new_vector
|
||||
for idx in indices_to_replace:
|
||||
vector_idx = idx - indices_to_replace[0]
|
||||
embedding_layer.weight[idx] = new_vector[vector_idx]
|
||||
|
||||
# adds it to the tokenizer
|
||||
def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
|
||||
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||
image_prompt_embeds = self.embedder(clip_image_embeds)
|
||||
# todo add support for multiple batch sizes
|
||||
if image_prompt_embeds.shape[0] != 1:
|
||||
raise ValueError("Batch size must be 1 for embedder for now")
|
||||
|
||||
# output on sd1.5 is bs, num_tokens, 768
|
||||
if len(self.text_encoder_list) == 1:
|
||||
# add it to the text encoder
|
||||
self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
|
||||
else:
|
||||
raise ValueError("Multiple text encoders not supported yet")
|
||||
# just a place to put a breakpoint
|
||||
pass
|
||||
|
||||
def restore_embeddings(self):
|
||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(
|
||||
self.text_encoder_list,
|
||||
self.tokenizer_list,
|
||||
self.orig_embeds_params,
|
||||
self.placeholder_token_ids
|
||||
):
|
||||
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
|
||||
index_no_updates[
|
||||
min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
|
||||
with torch.no_grad():
|
||||
text_encoder.get_input_embeddings().weight[
|
||||
index_no_updates
|
||||
] = orig_embeds[index_no_updates]
|
||||
# detach it all
|
||||
text_encoder.get_input_embeddings().weight.detach_()
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
self.image_encoder.gradient_checkpointing = True
|
||||
|
||||
def inject_trigger_into_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
|
||||
output_prompt = prompt
|
||||
embedding_tokens = self.embedding_tokens[0] # shoudl be the same
|
||||
default_replacements = ["[name]", "[trigger]"]
|
||||
|
||||
replace_with = embedding_tokens if expand_token else self.trigger
|
||||
if to_replace_list is None:
|
||||
to_replace_list = default_replacements
|
||||
else:
|
||||
to_replace_list += default_replacements
|
||||
|
||||
# remove duplicates
|
||||
to_replace_list = list(set(to_replace_list))
|
||||
|
||||
# replace them all
|
||||
for to_replace in to_replace_list:
|
||||
# replace it
|
||||
output_prompt = output_prompt.replace(to_replace, replace_with)
|
||||
|
||||
# see how many times replace_with is in the prompt
|
||||
num_instances = output_prompt.count(replace_with)
|
||||
|
||||
if num_instances == 0 and add_if_not_present:
|
||||
# add it to the beginning of the prompt
|
||||
output_prompt = replace_with + " " + output_prompt
|
||||
|
||||
if num_instances > 1:
|
||||
print(
|
||||
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
|
||||
return output_prompt
|
||||
|
||||
# reverses injection with class name. useful for normalizations
|
||||
def inject_trigger_class_name_into_prompt(self, prompt):
|
||||
output_prompt = prompt
|
||||
embedding_tokens = self.embedding_tokens[0] # shoudl be the same
|
||||
|
||||
default_replacements = ["[name]", "[trigger]", embedding_tokens, self.trigger]
|
||||
|
||||
replace_with = self.config.trigger_class_name
|
||||
to_replace_list = default_replacements
|
||||
|
||||
# remove duplicates
|
||||
to_replace_list = list(set(to_replace_list))
|
||||
|
||||
# replace them all
|
||||
for to_replace in to_replace_list:
|
||||
# replace it
|
||||
output_prompt = output_prompt.replace(to_replace, replace_with)
|
||||
|
||||
# see how many times replace_with is in the prompt
|
||||
num_instances = output_prompt.count(replace_with)
|
||||
|
||||
if num_instances > 1:
|
||||
print(
|
||||
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
|
||||
|
||||
return output_prompt
|
||||
@@ -14,6 +14,7 @@ SaveFormat = Literal['safetensors', 'diffusers']
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.guidance import GuidanceType
|
||||
|
||||
|
||||
class SaveConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.save_every: int = kwargs.get('save_every', 1000)
|
||||
@@ -47,7 +48,8 @@ class SampleConfig:
|
||||
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
|
||||
self.ext: ImgExt = kwargs.get('format', 'jpg')
|
||||
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
|
||||
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) # step to start using refiner on sample if it exists
|
||||
self.refiner_start_at = kwargs.get('refiner_start_at',
|
||||
0.5) # step to start using refiner on sample if it exists
|
||||
|
||||
|
||||
class LormModuleSettingsConfig:
|
||||
@@ -130,7 +132,7 @@ AdapterTypes = Literal['t2i', 'ip', 'ip+']
|
||||
|
||||
class AdapterConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip
|
||||
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip
|
||||
self.in_channels: int = kwargs.get('in_channels', 3)
|
||||
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
|
||||
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
|
||||
@@ -153,15 +155,9 @@ class AdapterConfig:
|
||||
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
|
||||
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid
|
||||
|
||||
|
||||
|
||||
class ClipTokenMakerConfig:
|
||||
def __init__(self, **kwargs):
|
||||
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
|
||||
self.num_tokens: int = kwargs.get('num_tokens', 8)
|
||||
|
||||
|
||||
|
||||
# clip vision
|
||||
self.trigger = kwargs.get('trigger', 'tri993r')
|
||||
self.trigger_class_name = kwargs.get('trigger_class_name', 'person')
|
||||
|
||||
|
||||
class EmbeddingConfig:
|
||||
@@ -401,7 +397,8 @@ class DatasetConfig:
|
||||
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
|
||||
self.mask_path: str = kwargs.get('mask_path',
|
||||
None) # focus mask (black and white. White has higher loss than black)
|
||||
self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located
|
||||
self.unconditional_path: str = kwargs.get('unconditional_path',
|
||||
None) # path where matching unconditional images are located
|
||||
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
|
||||
self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for . 0 - 1
|
||||
self.poi: Union[str, None] = kwargs.get('poi',
|
||||
|
||||
160
toolkit/resampler.py
Normal file
160
toolkit/resampler.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
||||
# and https://github.com/tencent-ailab/IP-Adapter/blob/9fc189e3fb389cc2b60a7d0c0850e083a716ea6e/ip_adapter/resampler.py
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias=False),
|
||||
nn.GELU(),
|
||||
nn.Linear(inner_dim, dim, bias=False),
|
||||
)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
|
||||
bs, length, width = x.shape
|
||||
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
||||
x = x.view(bs, length, heads, -1)
|
||||
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
||||
x = x.transpose(1, 2)
|
||||
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
||||
x = x.reshape(bs, heads, length, -1)
|
||||
return x
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
|
||||
def __init__(self, *, dim, dim_head=64, heads=8):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
self.dim_head = dim_head
|
||||
self.heads = heads
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x, latents):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): image features
|
||||
shape (b, n1, D)
|
||||
latent (torch.Tensor): latent features
|
||||
shape (b, n2, D)
|
||||
"""
|
||||
x = self.norm1(x)
|
||||
latents = self.norm2(latents)
|
||||
|
||||
b, l, _ = latents.shape
|
||||
|
||||
q = self.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
q = reshape_tensor(q, self.heads)
|
||||
k = reshape_tensor(k, self.heads)
|
||||
v = reshape_tensor(v, self.heads)
|
||||
|
||||
# attention
|
||||
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
||||
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
out = weight @ v
|
||||
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Resampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim=1024,
|
||||
depth=8,
|
||||
dim_head=64,
|
||||
heads=16,
|
||||
num_queries=8,
|
||||
embedding_dim=768,
|
||||
output_dim=1024,
|
||||
ff_mult=4,
|
||||
max_seq_len: int = 257, # CLIP tokens + CLS token
|
||||
apply_pos_emb: bool = False,
|
||||
num_latents_mean_pooled: int = 0,
|
||||
# number of latents derived from mean pooled representation of the sequence
|
||||
):
|
||||
super().__init__()
|
||||
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
|
||||
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, dim)
|
||||
|
||||
self.proj_out = nn.Linear(dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
|
||||
self.to_latents_from_mean_pooled_seq = (
|
||||
nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, dim * num_latents_mean_pooled),
|
||||
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
|
||||
)
|
||||
if num_latents_mean_pooled > 0
|
||||
else None
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
||||
FeedForward(dim=dim, mult=ff_mult),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
if self.pos_emb is not None:
|
||||
n, device = x.shape[1], x.device
|
||||
pos_emb = self.pos_emb(torch.arange(n, device=device))
|
||||
x = x + pos_emb
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
|
||||
if self.to_latents_from_mean_pooled_seq:
|
||||
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
|
||||
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
|
||||
latents = torch.cat((meanpooled_latents, latents), dim=-2)
|
||||
|
||||
for attn, ff in self.layers:
|
||||
latents = attn(x, latents) + latents
|
||||
latents = ff(latents) + latents
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
return self.norm_out(latents)
|
||||
|
||||
|
||||
def masked_mean(t, *, dim, mask=None):
|
||||
if mask is None:
|
||||
return t.mean(dim=dim)
|
||||
|
||||
denom = mask.sum(dim=dim, keepdim=True)
|
||||
mask = rearrange(mask, "b n -> b n 1")
|
||||
masked_t = t.masked_fill(~mask, 0.0)
|
||||
|
||||
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
|
||||
@@ -18,6 +18,7 @@ from torch.utils.checkpoint import checkpoint
|
||||
from tqdm import tqdm
|
||||
from torchvision.transforms import Resize, transforms
|
||||
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict, load_vae
|
||||
@@ -472,7 +473,7 @@ class StableDiffusion:
|
||||
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
|
||||
extra['image'] = validation_image
|
||||
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
|
||||
if isinstance(self.adapter, IPAdapter):
|
||||
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
@@ -483,6 +484,12 @@ class StableDiffusion:
|
||||
torch.manual_seed(gen_config.seed)
|
||||
torch.cuda.manual_seed(gen_config.seed)
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \
|
||||
and gen_config.adapter_image_path is not None:
|
||||
# run through the adapter to saturate the embeds
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
|
||||
self.adapter(conditional_clip_embeds)
|
||||
|
||||
# encode the prompt ourselves so we can do fun stuff with embeddings
|
||||
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
|
||||
|
||||
@@ -496,8 +503,8 @@ class StableDiffusion:
|
||||
unconditional_embeds,
|
||||
)
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter,
|
||||
IPAdapter) and gen_config.adapter_image_path is not None:
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
|
||||
and gen_config.adapter_image_path is not None:
|
||||
|
||||
# apply the image projection
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
|
||||
@@ -1445,6 +1452,9 @@ class StableDiffusion:
|
||||
elif isinstance(self.adapter, T2IAdapter):
|
||||
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
|
||||
adapter_device = self.adapter.device
|
||||
elif isinstance(self.adapter, ClipVisionAdapter):
|
||||
requires_grad = self.adapter.embedder.training
|
||||
adapter_device = self.adapter.device
|
||||
else:
|
||||
raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
|
||||
self.device_state['adapter'] = {
|
||||
|
||||
Reference in New Issue
Block a user