WIP on clip vision encoder

This commit is contained in:
Jaret Burkett
2024-03-13 07:24:08 -06:00
parent d87b49882c
commit 72de68d8aa
4 changed files with 164 additions and 73 deletions

View File

@@ -740,14 +740,36 @@ class SDTrainer(BaseSDTrainProcess):
embeds_to_use = conditional_embeds.clone().detach()
# handle clip vision adapter by removing triggers from prompt and replacing with the class name
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
prompt_list = batch.get_caption_list()
class_name = ''
triggers = ['[trigger]', '[name]']
remove_tokens = []
if self.embed_config is not None:
triggers.append(self.embed_config.trigger)
for i in range(1, self.embed_config.tokens):
remove_tokens.append(f"{self.embed_config.trigger}_{i}")
if self.embed_config.trigger_class_name is not None:
class_name = self.embed_config.trigger_class_name
if self.adapter is not None:
triggers.append(self.adapter_config.trigger)
for i in range(1, self.adapter_config.num_tokens):
remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
if self.adapter_config.trigger_class_name is not None:
class_name = self.adapter_config.trigger_class_name
for idx, prompt in enumerate(prompt_list):
prompt = self.adapter.inject_trigger_class_name_into_prompt(prompt)
for remove_token in remove_tokens:
prompt = prompt.replace(remove_token, '')
for trigger in triggers:
prompt = prompt.replace(trigger, class_name)
prompt_list[idx] = prompt
embeds_to_use = self.sd.encode_prompt(
prompt,
prompt_list,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype).detach()
@@ -1030,7 +1052,8 @@ class SDTrainer(BaseSDTrainProcess):
if has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True
is_training=True,
has_been_preprocessed=True
)
else:
# just do a blank one
@@ -1039,7 +1062,9 @@ class SDTrainer(BaseSDTrainProcess):
(noisy_latents.shape[0], 3, 512, 512),
device=self.device_torch, dtype=dtype
),
is_training=True
is_training=True,
has_been_preprocessed=True,
drop=True
)
# it will be injected into the tokenizer when called
self.adapter(conditional_clip_embeds)

View File

@@ -4,6 +4,8 @@ import torch
import weakref
from toolkit.config_modules import AdapterConfig
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
@@ -24,11 +26,11 @@ import torch.nn as nn
class Embedder(nn.Module):
def __init__(
self,
num_input_tokens: int = 50,
num_input_tokens: int = 1,
input_dim: int = 1024,
num_output_tokens: int = 8,
output_dim: int = 768,
mid_dim: int = 128,
mid_dim: int = 1024
):
super(Embedder, self).__init__()
self.num_output_tokens = num_output_tokens
@@ -36,29 +38,24 @@ class Embedder(nn.Module):
self.input_dim = input_dim
self.output_dim = output_dim
# Convolutional layer to reduce channel dimension
self.conv = nn.Conv1d(in_channels=input_dim, out_channels=mid_dim, kernel_size=1)
# GELU Activation
self.layer_norm = nn.LayerNorm(input_dim)
self.fc1 = nn.Linear(input_dim, mid_dim)
self.gelu = nn.GELU()
self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
# Layer Normalization
self.layer_norm = nn.LayerNorm(mid_dim)
# Adaptive pooling to change sequence length
self.adaptive_pool = nn.AdaptiveAvgPool1d(num_output_tokens)
# Linear layer for final transformation
self.final_linear = nn.Linear(mid_dim, output_dim)
self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))
def forward(self, x):
x = x.permute(0, 2, 1) # Adjust for Conv1d
x = self.conv(x)
x = self.layer_norm(x)
x = self.fc1(x)
x = self.gelu(x)
x = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1) # Apply LayerNorm
x = self.adaptive_pool(x)
x = x.permute(0, 2, 1) # Adjust back
x = self.final_linear(x)
x = self.fc2(x)
x = x.view(-1, self.num_output_tokens, self.output_dim)
# repeat the static tokens for each batch
static_tokens = torch.stack([self.static_tokens] * x.shape[0])
x = static_tokens + x
return x
@@ -140,24 +137,29 @@ class ClipVisionAdapter(torch.nn.Module):
self.image_encoder.train()
else:
self.image_encoder.eval()
# self.embedder = Embedder(
# num_output_tokens=self.config.num_tokens,
# num_input_tokens=self.image_encoder.config.top_k, # max_position_embeddings ?
# input_dim=self.image_encoder.config.hidden_size,
# output_dim=sd.unet.config['cross_attention_dim'],
# ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
heads = 12 if not sd.is_xl else 20
# dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
dim = sd.unet.config['cross_attention_dim']
self.embedder = Resampler(
dim=dim,
depth=4,
dim_head=64,
heads=heads,
num_queries=self.config.num_tokens, # usually 16
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=sd.unet.config['cross_attention_dim'],
ff_mult=4
# max_seq_len = CLIP tokens + CLS token
image_encoder_state_dict = self.image_encoder.state_dict()
in_tokens = 257
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
# clip
in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
if hasattr(self.image_encoder.config, 'hidden_sizes'):
embedding_dim = self.image_encoder.config.hidden_sizes[-1]
else:
embedding_dim = self.image_encoder.config.hidden_size
if self.config.clip_layer == 'image_embeds':
in_tokens = 1
embedding_dim = self.image_encoder.config.projection_dim
self.embedder = Embedder(
num_output_tokens=self.config.num_tokens,
num_input_tokens=in_tokens,
input_dim=embedding_dim,
output_dim=self.sd_ref().unet.config['cross_attention_dim'],
mid_dim=embedding_dim * self.config.num_tokens,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
self.embedder.train()
@@ -186,37 +188,76 @@ class ClipVisionAdapter(torch.nn.Module):
def get_clip_image_embeds_from_tensors(
self, tensors_0_1: torch.Tensor, drop=False,
is_training=False
is_training=False,
has_been_preprocessed=False
) -> torch.Tensor:
with torch.no_grad():
# tensors should be 0-1
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
if not has_been_preprocessed:
# tensors should be 0-1
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
if drop:
clip_image = clip_image * 0
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
# unconditional
if drop:
if self.clip_noise_zero:
tensors_0_1 = torch.rand_like(tensors_0_1).detach()
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
dtype=get_torch_dtype(self.sd_ref().dtype))
tensors_0_1 = tensors_0_1 * noise_scale
else:
tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
# tensors_0_1 = tensors_0_1 * 0
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
else:
if drop:
# scale the noise down
if self.clip_noise_zero:
tensors_0_1 = torch.rand_like(tensors_0_1).detach()
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
dtype=get_torch_dtype(self.sd_ref().dtype))
tensors_0_1 = tensors_0_1 * noise_scale
else:
tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
# tensors_0_1 = tensors_0_1 * 0
mean = torch.tensor(self.clip_image_processor.image_mean).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
std = torch.tensor(self.clip_image_processor.image_std).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
else:
self.image_encoder.eval()
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
clip_image_embeds = clip_output.hidden_states[-2]
if self.config.clip_layer == 'penultimate_hidden_states':
# they skip last layer for ip+
# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
clip_image_embeds = clip_output.hidden_states[-2]
elif self.config.clip_layer == 'last_hidden_state':
clip_image_embeds = clip_output.hidden_states[-1]
else:
clip_image_embeds = clip_output.image_embeds
return clip_image_embeds
import torch
@@ -236,6 +277,9 @@ class ClipVisionAdapter(torch.nn.Module):
# adds it to the tokenizer
def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
if clip_image_embeds.ndim == 2:
# expand the token dimension
clip_image_embeds = clip_image_embeds.unsqueeze(1)
image_prompt_embeds = self.embedder(clip_image_embeds)
# todo add support for multiple batch sizes
if image_prompt_embeds.shape[0] != 1:

View File

@@ -166,7 +166,7 @@ class AdapterConfig:
# clip vision
self.trigger = kwargs.get('trigger', 'tri993r')
self.trigger_class_name = kwargs.get('trigger_class_name', 'person')
self.trigger_class_name = kwargs.get('trigger_class_name', None)
self.class_names = kwargs.get('class_names', [])
@@ -188,6 +188,7 @@ class EmbeddingConfig:
self.tokens = kwargs.get('tokens', 4)
self.init_words = kwargs.get('init_words', '*')
self.save_format = kwargs.get('save_format', 'safetensors')
self.trigger_class_name = kwargs.get('trigger_class_name', None) # used for inverted masked prior
ContentOrStyleType = Literal['balanced', 'style', 'content']

View File

@@ -39,7 +39,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
StableDiffusionXLImg2ImgPipeline, LCMScheduler
StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel
import diffusers
from diffusers import \
AutoencoderKL, \
@@ -242,10 +242,21 @@ class StableDiffusion:
device_map="auto",
torch_dtype=self.torch_dtype,
)
# load the transformer
subfolder = "transformer"
# check if it is just the unet
if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
subfolder = None
# load the transformer only from the save
transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, subfolder=subfolder)
# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
model_path,
"PixArt-alpha/PixArt-XL-2-1024-MS",
transformer=transformer,
text_encoder=text_encoder,
dtype=dtype,
device=self.device_torch,
@@ -1081,10 +1092,14 @@ class StableDiffusion:
else:
noise_pred = noise_pred
else:
if self.unet.device != self.device_torch:
self.unet.to(self.device_torch)
if self.unet.dtype != self.torch_dtype:
self.unet = self.unet.to(dtype=self.torch_dtype)
noise_pred = self.unet(
latent_model_input.to(self.device_torch, self.torch_dtype),
timestep,
encoder_hidden_states=text_embeddings.text_embeds,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
**kwargs,
).sample
@@ -1485,10 +1500,16 @@ class StableDiffusion:
# saving in diffusers format
if not output_file.endswith('.safetensors'):
# diffusers
self.pipeline.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
if self.is_pixart:
self.unet.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
else:
self.pipeline.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
# save out meta config
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
with open(meta_path, 'w') as f: