Added a clip vision adapter trainer. Only works for sd15 for now

This commit is contained in:
Jaret Burkett
2023-12-24 13:26:04 -07:00
parent 0f8daa5612
commit 05ae95ca89
6 changed files with 586 additions and 20 deletions

View File

@@ -5,6 +5,7 @@ from diffusers import T2IAdapter
import torch.functional as F
from toolkit import train_tools
from toolkit.basic import value_map, adain, get_mean_std
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import GuidanceConfig
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.guidance import get_targeted_guidance_loss, get_guidance_loss
@@ -504,6 +505,7 @@ class SDTrainer(BaseSDTrainProcess):
noise: torch.Tensor,
**kwargs
):
# todo for embeddings, we need to run without trigger words
was_unet_training = self.sd.unet.training
was_network_active = False
if self.network is not None:
@@ -519,13 +521,28 @@ class SDTrainer(BaseSDTrainProcess):
# do a prediction here so we can match its output with network multiplier set to 0.0
with torch.no_grad():
dtype = get_torch_dtype(self.train_config.dtype)
embeds_to_use = conditional_embeds.clone().detach()
# handle clip vision adapter by removing triggers from prompt and replacing with the class name
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
prompt_list = batch.get_caption_list()
for idx, prompt in enumerate(prompt_list):
prompt = self.adapter.inject_trigger_class_name_into_prompt(prompt)
prompt_list[idx] = prompt
embeds_to_use = self.sd.encode_prompt(
prompt,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype).detach()
# dont use network on this
# self.network.multiplier = 0.0
self.sd.unet.eval()
prior_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype).detach(),
conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
timestep=timesteps,
guidance_scale=1.0,
**pred_kwargs # adapter residuals in here
@@ -666,6 +683,9 @@ class SDTrainer(BaseSDTrainProcess):
if self.embedding:
grad_on_text_encoder = True
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
grad_on_text_encoder = True
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
@@ -745,6 +765,26 @@ class SDTrainer(BaseSDTrainProcess):
prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
with network:
# encode clip adapter here so embeds are active for tokenizer
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('encode_clip_vision_embeds'):
if has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True
)
else:
# just do a blank one
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
torch.zeros(
(noisy_latents.shape[0], 3, 512, 512),
device=self.device_torch, dtype=dtype
),
is_training=True
)
# it will be injected into the tokenizer when called
self.adapter(conditional_clip_embeds)
with self.timer('encode_prompt'):
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
@@ -912,6 +952,10 @@ class SDTrainer(BaseSDTrainProcess):
with self.timer('restore_embeddings'):
# Let's make sure we don't update any embedding weights besides the newly added token
self.embedding.restore_embeddings()
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
with self.timer('restore_adapter'):
# Let's make sure we don't update any embedding weights besides the newly added token
self.adapter.restore_embeddings()
loss_dict = OrderedDict(
{'loss': loss.item()}

View File

@@ -5,7 +5,7 @@ import json
import shutil
from collections import OrderedDict
import os
from typing import Union, List
from typing import Union, List, Optional
import numpy as np
import yaml
@@ -17,6 +17,7 @@ import torch
import torch.backends.cuda
from toolkit.basic import value_map
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
@@ -138,7 +139,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, ClipVisionAdapter, None] = None
self.embedding: Union[Embedding, None] = None
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
@@ -202,7 +203,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
# ie test123 will become test123 test123_1 test123_2 etc. Do not add this yourself here
if self.embedding is not None:
prompt = self.embedding.inject_embedding_to_prompt(
prompt, add_if_not_present=False
prompt, expand_token=True, add_if_not_present=False
)
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
prompt = self.adapter.inject_trigger_into_prompt(
prompt, expand_token=True, add_if_not_present=False
)
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
@@ -400,6 +405,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
# add _lora to name
if self.adapter_config.type == 't2i':
adapter_name += '_t2i'
elif self.adapter_config.type == 'clip':
adapter_name += '_clip'
else:
adapter_name += '_ip'
@@ -647,6 +654,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
add_if_not_present=not is_reg,
)
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
prompt = self.adapter.inject_trigger_into_prompt(
prompt,
expand_token=True,
add_if_not_present=not is_reg,
)
# make sure trigger is in the prompts if not a regularization run
if self.trigger_word is not None:
prompt = self.sd.inject_trigger_into_prompt(
@@ -840,7 +854,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
def setup_adapter(self):
# t2i adapter
is_t2i = self.adapter_config.type == 't2i'
suffix = 't2i' if is_t2i else 'ip'
if self.adapter_config.type == 't2i':
suffix = 't2i'
elif self.adapter_config.type == 'clip':
suffix = 'clip'
else:
suffix = 'ip'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
@@ -865,6 +884,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
elif self.adapter_config.type == 'clip':
self.adapter = ClipVisionAdapter(
sd=self.sd,
adapter_config=self.adapter_config,
)
else:
self.adapter = IPAdapter(
sd=self.sd,

View File

@@ -0,0 +1,331 @@
from typing import TYPE_CHECKING, Mapping, Any
import torch
import weakref
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
CLIPVisionModel
)
from toolkit.resampler import Resampler
import torch.nn as nn
class Embedder(nn.Module):
def __init__(
self,
num_input_tokens: int = 50,
input_dim: int = 1024,
num_output_tokens: int = 8,
output_dim: int = 768,
mid_dim: int = 128,
):
super(Embedder, self).__init__()
self.num_output_tokens = num_output_tokens
self.num_input_tokens = num_input_tokens
self.input_dim = input_dim
self.output_dim = output_dim
# Convolutional layer to reduce channel dimension
self.conv = nn.Conv1d(in_channels=input_dim, out_channels=mid_dim, kernel_size=1)
# GELU Activation
self.gelu = nn.GELU()
# Layer Normalization
self.layer_norm = nn.LayerNorm(mid_dim)
# Adaptive pooling to change sequence length
self.adaptive_pool = nn.AdaptiveAvgPool1d(num_output_tokens)
# Linear layer for final transformation
self.final_linear = nn.Linear(mid_dim, output_dim)
def forward(self, x):
x = x.permute(0, 2, 1) # Adjust for Conv1d
x = self.conv(x)
x = self.gelu(x)
x = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1) # Apply LayerNorm
x = self.adaptive_pool(x)
x = x.permute(0, 2, 1) # Adjust back
x = self.final_linear(x)
return x
class ClipVisionAdapter(torch.nn.Module):
def __init__(self, sd: 'StableDiffusion', adapter_config: AdapterConfig):
super().__init__()
self.config = adapter_config
self.trigger = adapter_config.trigger
self.trigger_class_name = adapter_config.trigger_class_name
self.sd_ref: weakref.ref = weakref.ref(sd)
# embedding stuff
self.text_encoder_list = sd.text_encoder if isinstance(sd.text_encoder, list) else [sd.text_encoder]
self.tokenizer_list = sd.tokenizer if isinstance(sd.tokenizer, list) else [sd.tokenizer]
placeholder_tokens = [self.trigger]
# add dummy tokens for multi-vector
additional_tokens = []
for i in range(1, self.config.num_tokens):
additional_tokens.append(f"{self.trigger}_{i}")
placeholder_tokens += additional_tokens
# handle dual tokenizer
self.tokenizer_list = self.sd_ref().tokenizer if isinstance(self.sd_ref().tokenizer, list) else [
self.sd_ref().tokenizer]
self.text_encoder_list = self.sd_ref().text_encoder if isinstance(self.sd_ref().text_encoder, list) else [
self.sd_ref().text_encoder]
self.placeholder_token_ids = []
self.embedding_tokens = []
print(f"Adding {placeholder_tokens} tokens to tokenizer")
print(f"Adding {self.config.num_tokens} tokens to tokenizer")
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.config.num_tokens:
raise ValueError(
f"The tokenizer already contains the token {self.trigger}. Please pass a different"
f" `placeholder_token` that is not already in the tokenizer. Only added {num_added_tokens}"
)
# Convert the initializer_token, placeholder_token to ids
init_token_ids = tokenizer.encode(self.config.trigger_class_name, add_special_tokens=False)
# if length of token ids is more than number of orm embedding tokens fill with *
if len(init_token_ids) > self.config.num_tokens:
init_token_ids = init_token_ids[:self.config.num_tokens]
elif len(init_token_ids) < self.config.num_tokens:
pad_token_id = tokenizer.encode(["*"], add_special_tokens=False)
init_token_ids += pad_token_id * (self.config.num_tokens - len(init_token_ids))
placeholder_token_ids = tokenizer.encode(placeholder_tokens, add_special_tokens=False)
self.placeholder_token_ids.append(placeholder_token_ids)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder.resize_token_embeddings(len(tokenizer))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
for initializer_token_id, token_id in zip(init_token_ids, placeholder_token_ids):
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
# replace "[name] with this. on training. This is automatically generated in pipeline on inference
self.embedding_tokens.append(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids)))
# backup text encoder embeddings
self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(self.config.image_encoder_path)
except EnvironmentError:
self.clip_image_processor = CLIPImageProcessor()
self.device = self.sd_ref().unet.device
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
self.config.image_encoder_path,
ignore_mismatched_sizes=True
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
if self.config.train_image_encoder:
self.image_encoder.train()
else:
self.image_encoder.eval()
# self.embedder = Embedder(
# num_output_tokens=self.config.num_tokens,
# num_input_tokens=self.image_encoder.config.top_k, # max_position_embeddings ?
# input_dim=self.image_encoder.config.hidden_size,
# output_dim=sd.unet.config['cross_attention_dim'],
# ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
heads = 12 if not sd.is_xl else 20
# dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
dim = sd.unet.config['cross_attention_dim']
self.embedder = Resampler(
dim=dim,
depth=4,
dim_head=64,
heads=heads,
num_queries=self.config.num_tokens, # usually 16
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=sd.unet.config['cross_attention_dim'],
ff_mult=4
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
self.embedder.train()
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
state_dict = {
'embedder': self.embedder.state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
}
if self.config.train_image_encoder:
state_dict['image_encoder'] = self.image_encoder.state_dict(
*args, destination=destination, prefix=prefix,
keep_vars=keep_vars)
return state_dict
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self.embedder.load_state_dict(state_dict["embedder"], strict=strict)
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
def parameters(self, *args, **kwargs):
yield from self.embedder.parameters(*args, **kwargs)
def named_parameters(self, *args, **kwargs):
yield from self.embedder.named_parameters(*args, **kwargs)
def get_clip_image_embeds_from_tensors(
self, tensors_0_1: torch.Tensor, drop=False,
is_training=False
) -> torch.Tensor:
with torch.no_grad():
# tensors should be 0-1
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
if drop:
clip_image = clip_image * 0
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
else:
self.image_encoder.eval()
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
clip_image_embeds = clip_output.hidden_states[-2]
return clip_image_embeds
import torch
def set_vec(self, new_vector, text_encoder_idx=0):
# Get the embedding layer
embedding_layer = self.text_encoder_list[text_encoder_idx].get_input_embeddings()
# Indices to replace in the embeddings
indices_to_replace = self.placeholder_token_ids[text_encoder_idx]
# Replace the specified embeddings with new_vector
for idx in indices_to_replace:
vector_idx = idx - indices_to_replace[0]
embedding_layer.weight[idx] = new_vector[vector_idx]
# adds it to the tokenizer
def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
image_prompt_embeds = self.embedder(clip_image_embeds)
# todo add support for multiple batch sizes
if image_prompt_embeds.shape[0] != 1:
raise ValueError("Batch size must be 1 for embedder for now")
# output on sd1.5 is bs, num_tokens, 768
if len(self.text_encoder_list) == 1:
# add it to the text encoder
self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
else:
raise ValueError("Multiple text encoders not supported yet")
# just a place to put a breakpoint
pass
def restore_embeddings(self):
# Let's make sure we don't update any embedding weights besides the newly added token
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(
self.text_encoder_list,
self.tokenizer_list,
self.orig_embeds_params,
self.placeholder_token_ids
):
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[
min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
with torch.no_grad():
text_encoder.get_input_embeddings().weight[
index_no_updates
] = orig_embeds[index_no_updates]
# detach it all
text_encoder.get_input_embeddings().weight.detach_()
def enable_gradient_checkpointing(self):
self.image_encoder.gradient_checkpointing = True
def inject_trigger_into_prompt(self, prompt, expand_token=False, to_replace_list=None, add_if_not_present=True):
output_prompt = prompt
embedding_tokens = self.embedding_tokens[0] # shoudl be the same
default_replacements = ["[name]", "[trigger]"]
replace_with = embedding_tokens if expand_token else self.trigger
if to_replace_list is None:
to_replace_list = default_replacements
else:
to_replace_list += default_replacements
# remove duplicates
to_replace_list = list(set(to_replace_list))
# replace them all
for to_replace in to_replace_list:
# replace it
output_prompt = output_prompt.replace(to_replace, replace_with)
# see how many times replace_with is in the prompt
num_instances = output_prompt.count(replace_with)
if num_instances == 0 and add_if_not_present:
# add it to the beginning of the prompt
output_prompt = replace_with + " " + output_prompt
if num_instances > 1:
print(
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt
# reverses injection with class name. useful for normalizations
def inject_trigger_class_name_into_prompt(self, prompt):
output_prompt = prompt
embedding_tokens = self.embedding_tokens[0] # shoudl be the same
default_replacements = ["[name]", "[trigger]", embedding_tokens, self.trigger]
replace_with = self.config.trigger_class_name
to_replace_list = default_replacements
# remove duplicates
to_replace_list = list(set(to_replace_list))
# replace them all
for to_replace in to_replace_list:
# replace it
output_prompt = output_prompt.replace(to_replace, replace_with)
# see how many times replace_with is in the prompt
num_instances = output_prompt.count(replace_with)
if num_instances > 1:
print(
f"Warning: {replace_with} token appears {num_instances} times in prompt {output_prompt}. This may cause issues.")
return output_prompt

View File

@@ -14,6 +14,7 @@ SaveFormat = Literal['safetensors', 'diffusers']
if TYPE_CHECKING:
from toolkit.guidance import GuidanceType
class SaveConfig:
def __init__(self, **kwargs):
self.save_every: int = kwargs.get('save_every', 1000)
@@ -47,7 +48,8 @@ class SampleConfig:
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.ext: ImgExt = kwargs.get('format', 'jpg')
self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
self.refiner_start_at = kwargs.get('refiner_start_at', 0.5) # step to start using refiner on sample if it exists
self.refiner_start_at = kwargs.get('refiner_start_at',
0.5) # step to start using refiner on sample if it exists
class LormModuleSettingsConfig:
@@ -130,7 +132,7 @@ AdapterTypes = Literal['t2i', 'ip', 'ip+']
class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip
self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
@@ -153,15 +155,9 @@ class AdapterConfig:
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid
class ClipTokenMakerConfig:
def __init__(self, **kwargs):
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
self.num_tokens: int = kwargs.get('num_tokens', 8)
# clip vision
self.trigger = kwargs.get('trigger', 'tri993r')
self.trigger_class_name = kwargs.get('trigger_class_name', 'person')
class EmbeddingConfig:
@@ -401,7 +397,8 @@ class DatasetConfig:
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
self.mask_path: str = kwargs.get('mask_path',
None) # focus mask (black and white. White has higher loss than black)
self.unconditional_path: str = kwargs.get('unconditional_path', None) # path where matching unconditional images are located
self.unconditional_path: str = kwargs.get('unconditional_path',
None) # path where matching unconditional images are located
self.invert_mask: bool = kwargs.get('invert_mask', False) # invert mask
self.mask_min_value: float = kwargs.get('mask_min_value', 0.0) # min value for . 0 - 1
self.poi: Union[str, None] = kwargs.get('poi',

160
toolkit/resampler.py Normal file
View File

@@ -0,0 +1,160 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
# and https://github.com/tencent-ailab/IP-Adapter/blob/9fc189e3fb389cc2b60a7d0c0850e083a716ea6e/ip_adapter/resampler.py
import math
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head ** -0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
max_seq_len: int = 257, # CLIP tokens + CLS token
apply_pos_emb: bool = False,
num_latents_mean_pooled: int = 0,
# number of latents derived from mean pooled representation of the sequence
):
super().__init__()
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.to_latents_from_mean_pooled_seq = (
nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, dim * num_latents_mean_pooled),
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
)
if num_latents_mean_pooled > 0
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x):
if self.pos_emb is not None:
n, device = x.shape[1], x.device
pos_emb = self.pos_emb(torch.arange(n, device=device))
x = x + pos_emb
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
if self.to_latents_from_mean_pooled_seq:
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
latents = torch.cat((meanpooled_latents, latents), dim=-2)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)
def masked_mean(t, *, dim, mask=None):
if mask is None:
return t.mean(dim=dim)
denom = mask.sum(dim=dim, keepdim=True)
mask = rearrange(mask, "b n -> b n 1")
masked_t = t.masked_fill(~mask, 0.0)
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)

View File

@@ -18,6 +18,7 @@ from torch.utils.checkpoint import checkpoint
from tqdm import tqdm
from torchvision.transforms import Resize, transforms
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
@@ -472,7 +473,7 @@ class StableDiffusion:
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter):
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
transform = transforms.Compose([
transforms.ToTensor(),
])
@@ -483,6 +484,12 @@ class StableDiffusion:
torch.manual_seed(gen_config.seed)
torch.cuda.manual_seed(gen_config.seed)
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \
and gen_config.adapter_image_path is not None:
# run through the adapter to saturate the embeds
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
self.adapter(conditional_clip_embeds)
# encode the prompt ourselves so we can do fun stuff with embeddings
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
@@ -496,8 +503,8 @@ class StableDiffusion:
unconditional_embeds,
)
if self.adapter is not None and isinstance(self.adapter,
IPAdapter) and gen_config.adapter_image_path is not None:
if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
and gen_config.adapter_image_path is not None:
# apply the image projection
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
@@ -1445,6 +1452,9 @@ class StableDiffusion:
elif isinstance(self.adapter, T2IAdapter):
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
adapter_device = self.adapter.device
elif isinstance(self.adapter, ClipVisionAdapter):
requires_grad = self.adapter.embedder.training
adapter_device = self.adapter.device
else:
raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
self.device_state['adapter'] = {