Added a CLIP vision adapter trainer. It only works with SD 1.5 for now.

Jaret Burkett
2023-12-24 13:26:04 -07:00
parent 0f8daa5612
commit 05ae95ca89
6 changed files with 586 additions and 20 deletions
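Based on the AdapterConfig fields added in toolkit/config_modules.py below, enabling the new adapter would look roughly like this (an illustrative sketch, not a tested config; the image encoder path is a placeholder):

from toolkit.config_modules import AdapterConfig

# hypothetical usage sketch; field names and defaults come from this commit
adapter_config = AdapterConfig(
    type='clip',                                      # routes setup_adapter() to ClipVisionAdapter
    image_encoder_path='path/to/clip-vision-model',   # placeholder path to a CLIP vision checkpoint
    num_tokens=8,                                     # number of placeholder tokens the adapter writes into the prompt
    trigger='tri993r',                                # trigger word that expands into those tokens
    trigger_class_name='person',                      # class word used to initialize and to swap the trigger back out
    train_image_encoder=False,                        # optionally also finetune the CLIP vision tower
)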


@@ -5,6 +5,7 @@ from diffusers import T2IAdapter
import torch.functional as F
from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GuidanceConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
@@ -504,6 +505,7 @@ class SDTrainer(BaseSDTrainProcess):
noise: torch.Tensor,
**kwargs
):
# todo for embeddings, we need to run without trigger words
was_unet_training = self.sd.unet.training
was_network_active = False
if self.network is not None:
@@ -519,13 +521,28 @@ class SDTrainer(BaseSDTrainProcess):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
embeds_to_use = conditional_embeds.clone().detach()
# handle clip vision adapter by removing triggers from prompt and replacing with the class name
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
prompt_list = batch.get_caption_list()
for idx, prompt in enumerate(prompt_list):
prompt = self.adapter.inject_trigger_class_name_into_prompt(prompt)
prompt_list[idx] = prompt
embeds_to_use = self.sd.encode_prompt(
prompt,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype).detach()
# dont use network on this
# self.network.multiplier = 0.0
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs  # adapter residuals in here
@@ -666,6 +683,9 @@ class SDTrainer(BaseSDTrainProcess):
if self.embedding:
grad_on_text_encoder = True
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
grad_on_text_encoder = True
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
@@ -745,6 +765,26 @@ class SDTrainer(BaseSDTrainProcess):
prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
with network:
# encode clip adapter here so embeds are active for tokenizer
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('encode_clip_vision_embeds'):
if has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True
)
else:
# just do a blank one
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, 512, 512),
device=self.device_torch, dtype=dtype
),
is_training=True
)
# it will be injected into the tokenizer when called
self.adapter(conditional_clip_embeds)
with self.timer('encode_prompt'):
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
@@ -912,6 +952,10 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('restore_embeddings'):
# Let's make sure we don't update any embedding weights besides the newly added token
self.embedding.restore_embeddings()
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('restore_adapter'):
# Let's make sure we don't update any embedding weights besides the newly added token
self.adapter.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}
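Taken together, the SDTrainer changes give a per-step flow for the clip adapter roughly like the following (illustrative pseudocode using the method names from this diff; batching, timers, and the prior-prediction branch are simplified):

# rough sketch only, not the actual trainer loop
clip_embeds = adapter.get_clip_image_embeds_from_tensors(clip_images, is_training=True)
adapter(clip_embeds)  # projects image features and writes them into the placeholder token embeddings
prompt = adapter.inject_trigger_into_prompt(caption, expand_token=True)  # "tri993r tri993r_1 ..."
conditional_embeds = sd.encode_prompt(prompt, long_prompts=do_long_prompts)
pred = sd.predict_noise(latents=noisy_latents, conditional_embeddings=conditional_embeds, timestep=timesteps)
# for the prior prediction the trigger is swapped for the class name via inject_trigger_class_name_into_prompt()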


@@ -5,7 +5,7 @@ import json
import shutil
from collections import OrderedDict
import os
from typing import Union, List, Optional
import numpy as np
import yaml
@@ -17,6 +17,7 @@ import torch
import torch.backends.cuda
from toolkit.basic import value_map
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
@@ -138,7 +139,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, None] = None
self.embedding: Union[Embedding, None] = None
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
@@ -202,7 +203,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
# ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt, expand_token=True, add_if_not_present=False
)
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
prompt = self.adapter.inject_trigger_into_prompt(
prompt, expand_token=True, add_if_not_present=False
)
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
@@ -400,6 +405,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# add _lora to name
if self.adapter_config.type == 't2i':
adapter_name += '_t2i'
elif self.adapter_config.type == 'clip':
adapter_name += '_clip'
else:
adapter_name += '_ip'
@@ -647,6 +654,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
add_if_not_present=not is_reg,
)
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
prompt = self.adapter.inject_trigger_into_prompt(
prompt,
expand_token=True,
add_if_not_present=not is_reg,
)
# make sure trigger is in the prompts if not a regularization run
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
@@ -840,7 +854,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
def setup_adapter(self):
# t2i adapter
is_t2i = self.adapter_config.type == 't2i'
if self.adapter_config.type == 't2i':
suffix = 't2i'
elif self.adapter_config.type == 'clip':
suffix = 'clip'
else:
suffix = 'ip'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
@@ -865,6 +884,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
elif self.adapter_config.type == 'clip':
self.adapter = ClipVisionAdapter(
sd=self.sd,
adapter_config=self.adapter_config,
)
else:
self.adapter = IPAdapter(
sd=self.sd,

toolkit/clip_vision_adapter.py (new file)

@@ -0,0 +1,331 @@
from typing import TYPE_CHECKING, Mapping, Any
import torch
import weakref
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
CLIPVisionModel
)
from toolkit.resampler import Resampler
import torch.nn as nn
class Embedder(nn.Module):
def __init__(
self,
num_input_tokens: int = 50,
input_dim: int = 1024,
num_output_tokens: int = 8,
output_dim: int = 768,
mid_dim: int = 128,
):
super(Embedder, self).__init__()
self.num_output_tokens = num_output_tokens
self.num_input_tokens = num_input_tokens
self.input_dim = input_dim
self.output_dim = output_dim
# Convolutional layer to reduce channel dimension
self.conv = nn.Conv1d(in_channels=input_dim, out_channels=mid_dim, kernel_size=1)
# GELU Activation
self.gelu = nn.GELU()
# Layer Normalization
self.layer_norm = nn.LayerNorm(mid_dim)
# Adaptive pooling to change sequence length
self.adaptive_pool = nn.AdaptiveAvgPool1d(num_output_tokens)
# Linear layer for final transformation
self.final_linear = nn.Linear(mid_dim, output_dim)
def forward(self, x):
x = x.permute(0, 2, 1) # Adjust for Conv1d
x = self.conv(x)
x = self.gelu(x)
x = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1) # Apply LayerNorm
x = self.adaptive_pool(x)
x = x.permute(0, 2, 1) # Adjust back
x = self.final_linear(x)
return x
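# Shape sketch for the Embedder above (currently unused; the ClipVisionAdapter below uses the Resampler instead).
# With the defaults, an input of (batch, 50, 1024) is permuted for Conv1d, reduced to mid_dim channels,
# pooled down to num_output_tokens, and projected, giving (batch, 8, 768):
#   Embedder()(torch.randn(1, 50, 1024)).shape  # -> torch.Size([1, 8, 768])  (illustrative check)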
class ClipVisionAdapter(torch.nn.Module):
def __init__(self, sd: 'StableDiffusion', adapter_config: AdapterConfig):
super().__init__()
self.config = adapter_config
self.trigger = adapter_config.trigger
self.trigger_class_name = adapter_config.trigger_class_name
self.sd_ref: weakref.ref = weakref.ref(sd)
# embedding stuff
self.text_encoder_list = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder]
self.tokenizer_list = sd.tokenizer if isinstance(sd.tokenizer, list) else [sd.tokenizer]
placeholder_tokens = [self.trigger]
# add dummy tokens for multi-vector
additional_tokens = []
for i in range(1, self.config.num_tokens):
additional_tokens.append(f"{self.trigger}_{i}")
placeholder_tokens += additional_tokens
# handle dual tokenizer
self.tokenizer_list = self.sd_ref().tokenizer if isinstance(self.sd_ref().tokenizer, list) else [
self.sd_ref().tokenizer]
self.text_encoder_list = self.sd_ref().text_encoder if isinstance(self.sd_ref().text_encoder, list) else [
self.sd_ref().text_encoder]
self.placeholder_token_ids = []
self.embedding_tokens = []
print(f"Adding {placeholder_tokens} tokens to tokenizer")
print(f"Adding {self.config.num_tokens} tokens to tokenizer")
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.config.num_tokens:
raise ValueError(
f"The tokenizer already contains the token {self.trigger}. Please pass a different"
f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
)
# Convert the initializer_token, placeholder_token to ids
init_token_ids = tokenizer.encode(self.config.trigger_class_name, add_special_tokens=False)
# if the length of the token ids is more than the number of our embedding tokens, truncate; if fewer, fill with *
if len(init_token_ids) > self.config.num_tokens:
init_token_ids = init_token_ids[:self.config.num_tokens]
elif len(init_token_ids) < self.config.num_tokens:
pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
init_token_ids += pad_token_id * (self.config.num_tokens - len(init_token_ids))
placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
self.placeholder_token_ids.append(placeholder_token_ids)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
# backup text encoder embeddings
self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path)
except EnvironmentError:
self.clip_image_processor = CLIPImageProcessor()
self.device = self.sd_ref().unet.device
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
self.config.image_encoder_path,
ignore_mismatched_sizes=True
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
if self.config.train_image_encoder:
self.image_encoder.train()
else:
self.image_encoder.eval()
# self.embedder = Embedder(
# num_output_tokens=self.config.num_tokens,
# num_input_tokens=self.image_encoder.config.top_k, # max_position_embeddings ?
# input_dim=self.image_encoder.config.hidden_size,
# output_dim=sd.unet.config['cross_attention_dim'],
# ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
heads = 12 if not sd.is_xl else 20
# dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
dim = sd.unet.config['cross_attention_dim']
self.embedder = Resampler(
dim=dim,
depth=4,
dim_head=64,
heads=heads,
num_queries=self.config.num_tokens, # usually 16
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=sd.unet.config['cross_attention_dim'],
ff_mult=4
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
self.embedder.train()
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
state_dict = {
'embedder': self.embedder.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
}
if self.config.train_image_encoder:
state_dict['image_encoder'] = self.image_encoder.state_dict(
*args, destination=destination, prefix=prefix,
keep_vars=keep_vars)
return state_dict
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self.embedder.load_state_dict(state_dict["embedder"], strict=strict)
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
def parameters(self, *args, **kwargs):
yield from self.embedder.parameters(*args, **kwargs)
def named_parameters(self, *args, **kwargs):
yield from self.embedder.named_parameters(*args, **kwargs)
def get_clip_image_embeds_from_tensors(
self, tensors_0_1: torch.Tensor, drop=False,
is_training=False
) -> torch.Tensor:
with torch.no_grad():
# tensors should be 0-1
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
if drop:
clip_image = clip_image * 0
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
else:
self.image_encoder.eval()
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
clip_image_embeds = clip_output.hidden_states[-2]
return clip_image_embeds
import torch
def set_vec(self, new_vector, text_encoder_idx=0):
# Get the embedding layer
embedding_layer = self.text_encoder_list[text_encoder_idx].get_input_embeddings()
# Indices to replace in the embeddings
indices_to_replace = self.placeholder_token_ids[text_encoder_idx]
# Replace the specified embeddings with new_vector
for idx in indices_to_replace:
vector_idx = idx - indices_to_replace[0]
embedding_layer.weight[idx] = new_vector[vector_idx]
# adds it to the tokenizer
def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
image_prompt_embeds = self.embedder(clip_image_embeds)
# todo add support for multiple batch sizes
if image_prompt_embeds.shape[0] != 1:
raise ValueError("Batch size must be 1 for embedder for now")
# output on sd1.5 is bs, num_tokens, 768
if len(self.text_encoder_list) == 1:
# add it to the text encoder
self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
else:
raise ValueError("Multiple text encoders not supported yet")
# just a place to put a breakpoint
pass
def restore_embeddings(self):
# Let's make sure we don't update any embedding weights besides the newly added token
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(
self.text_encoder_list,
self.tokenizer_list,
self.orig_embeds_params,
self.placeholder_token_ids
):
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[
min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
with torch.no_grad():
text_encoder.get_input_embeddings().weight[
index_no_updates
] = orig_embeds[index_no_updates]
# detach it all
text_encoder.get_input_embeddings().weight.detach_()
def enable_gradient_checkpointing(self):
self.image_encoder.gradient_checkpointing = True
def inject_trigger_into_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
output_prompt = prompt
embedding_tokens = self.embedding_tokens[0]  # should be the same
default_replacements = ["[name]", "[trigger]"]
replace_with = embedding_tokens if expand_token else self.trigger
if to_replace_list is None:
to_replace_list = default_replacements
else:
to_replace_list += default_replacements
# remove duplicates
to_replace_list = list(set(to_replace_list))
# replace them all
for to_replace in to_replace_list:
# replace it
output_prompt = output_prompt.replace(to_replace, replace_with)
# see how many times replace_with is in the prompt
num_instances = output_prompt.count(replace_with)
if num_instances == 0 and add_if_not_present:
# add it to the beginning of the prompt
output_prompt = replace_with + " " + output_prompt
if num_instances > 1:
print(
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt
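# Illustrative example (hypothetical prompt, default trigger 'tri993r', num_tokens=4 for brevity):
#   inject_trigger_into_prompt("photo of [trigger] at the beach", expand_token=True)
#   -> roughly "photo of tri993r tri993r_1 tri993r_2 tri993r_3 at the beach"
# With expand_token=False only the bare trigger word is substituted.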
# reverses injection with class name. useful for normalizations
def inject_trigger_class_name_into_prompt(self, prompt):
output_prompt = prompt
embedding_tokens = self.embedding_tokens[0]  # should be the same
default_replacements = ["[name]", "[trigger]", embedding_tokens, self.trigger]
replace_with = self.config.trigger_class_name
to_replace_list = default_replacements
# remove duplicates
to_replace_list = list(set(to_replace_list))
# replace them all
for to_replace in to_replace_list:
# replace it
output_prompt = output_prompt.replace(to_replace, replace_with)
# see how many times replace_with is in the prompt
num_instances = output_prompt.count(replace_with)
if num_instances > 1:
print(
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt
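# Illustrative example (hypothetical prompt, default trigger_class_name 'person'):
#   inject_trigger_class_name_into_prompt("photo of [trigger] at the beach") -> "photo of person at the beach"
# SDTrainer calls this when building the prior prediction so that pass sees a trigger-free, class-name prompt.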

toolkit/config_modules.py

@@ -14,6 +14,7 @@ SaveFormat = Literal['safetensors', 'diffusers']
if TYPE_CHECKING:
from toolkit.guidance import GuidanceType

class SaveConfig:
def __init__(self, **kwargs):
self.save_every: int = kwargs.get('save_every', 1000)
@@ -47,7 +48,8 @@ class SampleConfig:
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.ext: ImgExt = kwargs.get('format', 'jpg')
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5)  # step to start using refiner on sample if it exists

class LormModuleSettingsConfig:
@@ -130,7 +132,7 @@ AdapterTypes = Literal['t2i', 'ip', 'ip+']
class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i')  # t2i, ip, clip
self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
@@ -153,15 +155,9 @@ class AdapterConfig:
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip vit vit_hybrid
# clip vision
self.trigger = kwargs.get('trigger', 'tri993r')
self.trigger_class_name = kwargs.get('trigger_class_name', 'person')
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
self.num_tokens: int = kwargs.get('num_tokens', 8)
# (this hunk also removes the standalone ClipTokenMakerConfig class; its settings now live on AdapterConfig)

class EmbeddingConfig:
@@ -401,7 +397,8 @@ class DatasetConfig:
self.alpha_mask: bool = kwargs.get('alpha_mask', False)  # if true, will use alpha channel as mask
self.mask_path: str = kwargs.get('mask_path', None)  # focus mask (black and white. White has higher loss than black)
self.unconditional_path: str = kwargs.get('unconditional_path', None)  # path where matching unconditional images are located
self.invert_mask: bool = kwargs.get('invert_mask', False)  # invert mask
self.mask_min_value: float = kwargs.get('mask_min_value', 0.0)  # min value for . 0 - 1
self.poi: Union[str, None] = kwargs.get('poi',

toolkit/resampler.py (new file, 160 lines)

@@ -0,0 +1,160 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
# and https://github.com/tencent-ailab/IP-Adapter/blob/9fc189e3fb389cc2b60a7d0c0850e083a716ea6e/ip_adapter/resampler.py
import math
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head ** -0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
max_seq_len: int = 257, # CLIP tokens + CLS token
apply_pos_emb: bool = False,
num_latents_mean_pooled: int = 0,
# number of latents derived from mean pooled representation of the sequence
):
super().__init__()
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.to_latents_from_mean_pooled_seq = (
nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * num_latents_mean_pooled),
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
)
if num_latents_mean_pooled > 0
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x):
if self.pos_emb is not None:
n, device = x.shape[1], x.device
pos_emb = self.pos_emb(torch.arange(n, device=device))
x = x + pos_emb
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
if self.to_latents_from_mean_pooled_seq:
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
latents = torch.cat((meanpooled_latents, latents), dim=-2)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
def masked_mean(t, *, dim, mask=None):
if mask is None:
return t.mean(dim=dim)
denom = mask.sum(dim=dim, keepdim=True)
mask = rearrange(mask, "b n -> b n 1")
masked_t = t.masked_fill(~mask, 0.0)
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
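# Rough shape sketch of how ClipVisionAdapter drives this Resampler on SD 1.5 (illustrative values only;
# the real embedding_dim comes from image_encoder.config.hidden_size):
#   resampler = Resampler(dim=768, depth=4, dim_head=64, heads=12,
#                         num_queries=8, embedding_dim=1024, output_dim=768, ff_mult=4)
#   clip_hidden = torch.randn(1, 257, 1024)   # penultimate CLIP vision hidden states (patch tokens + CLS)
#   tokens = resampler(clip_hidden)           # -> (1, 8, 768), one row per placeholder token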

toolkit/stable_diffusion_model.py

@@ -18,6 +18,7 @@ from torch.utils.checkpoint import checkpoint
from tqdm import tqdm
from torchvision.transforms import Resize, transforms
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
@@ -472,7 +473,7 @@ class StableDiffusion:
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
transform = transforms.Compose([
transforms.ToTensor(),
])
@@ -483,6 +484,12 @@ class StableDiffusion:
torch.manual_seed(gen_config.seed)
torch.cuda.manual_seed(gen_config.seed)
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \
and gen_config.adapter_image_path is not None:
# run through the adapter to saturate the embeds
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
self.adapter(conditional_clip_embeds)
# encode the prompt ourselves so we can do fun stuff with embeddings
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
@@ -496,8 +503,8 @@ class StableDiffusion:
unconditional_embeds,
)
if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
and gen_config.adapter_image_path is not None:
# apply the image projection
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
@@ -1445,6 +1452,9 @@ class StableDiffusion:
elif isinstance(self.adapter, T2IAdapter):
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
adapter_device = self.adapter.device
elif isinstance(self.adapter, ClipVisionAdapter):
requires_grad = self.adapter.embedder.training
adapter_device = self.adapter.device
else:
raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
self.device_state['adapter'] = {