WIP on clip vision encoder

This commit is contained in:
Jaret Burkett
2024-03-13 07:24:08 -06:00
parent d87b49882c
commit 72de68d8aa
4 changed files with 164 additions and 73 deletions

View File

@@ -4,6 +4,8 @@ import torch
import weakref
from toolkit.config_modules import AdapterConfig
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype
@@ -24,11 +26,11 @@ import torch.nn as nn
class Embedder(nn.Module):
def __init__(
self,
num_input_tokens: int = 50,
num_input_tokens: int = 1,
input_dim: int = 1024,
num_output_tokens: int = 8,
output_dim: int = 768,
mid_dim: int = 128,
mid_dim: int = 1024
):
super(Embedder, self).__init__()
self.num_output_tokens = num_output_tokens
@@ -36,29 +38,24 @@ class Embedder(nn.Module):
self.input_dim = input_dim
self.output_dim = output_dim
# Convolutional layer to reduce channel dimension
self.conv = nn.Conv1d(in_channels=input_dim, out_channels=mid_dim, kernel_size=1)
# GELU Activation
self.layer_norm = nn.LayerNorm(input_dim)
self.fc1 = nn.Linear(input_dim, mid_dim)
self.gelu = nn.GELU()
self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
# Layer Normalization
self.layer_norm = nn.LayerNorm(mid_dim)
# Adaptive pooling to change sequence length
self.adaptive_pool = nn.AdaptiveAvgPool1d(num_output_tokens)
# Linear layer for final transformation
self.final_linear = nn.Linear(mid_dim, output_dim)
self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))
def forward(self, x):
x = x.permute(0, 2, 1) # Adjust for Conv1d
x = self.conv(x)
x = self.layer_norm(x)
x = self.fc1(x)
x = self.gelu(x)
x = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1) # Apply LayerNorm
x = self.adaptive_pool(x)
x = x.permute(0, 2, 1) # Adjust back
x = self.final_linear(x)
x = self.fc2(x)
x = x.view(-1, self.num_output_tokens, self.output_dim)
# repeat the static tokens for each batch
static_tokens = torch.stack([self.static_tokens] * x.shape[0])
x = static_tokens + x
return x
@@ -140,24 +137,29 @@ class ClipVisionAdapter(torch.nn.Module):
self.image_encoder.train()
else:
self.image_encoder.eval()
# self.embedder = Embedder(
# num_output_tokens=self.config.num_tokens,
# num_input_tokens=self.image_encoder.config.top_k, # max_position_embeddings ?
# input_dim=self.image_encoder.config.hidden_size,
# output_dim=sd.unet.config['cross_attention_dim'],
# ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
heads = 12 if not sd.is_xl else 20
# dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
dim = sd.unet.config['cross_attention_dim']
self.embedder = Resampler(
dim=dim,
depth=4,
dim_head=64,
heads=heads,
num_queries=self.config.num_tokens, # usually 16
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=sd.unet.config['cross_attention_dim'],
ff_mult=4
# max_seq_len = CLIP tokens + CLS token
image_encoder_state_dict = self.image_encoder.state_dict()
in_tokens = 257
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
# clip
in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
if hasattr(self.image_encoder.config, 'hidden_sizes'):
embedding_dim = self.image_encoder.config.hidden_sizes[-1]
else:
embedding_dim = self.image_encoder.config.hidden_size
if self.config.clip_layer == 'image_embeds':
in_tokens = 1
embedding_dim = self.image_encoder.config.projection_dim
self.embedder = Embedder(
num_output_tokens=self.config.num_tokens,
num_input_tokens=in_tokens,
input_dim=embedding_dim,
output_dim=self.sd_ref().unet.config['cross_attention_dim'],
mid_dim=embedding_dim * self.config.num_tokens,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
self.embedder.train()
@@ -186,37 +188,76 @@ class ClipVisionAdapter(torch.nn.Module):
def get_clip_image_embeds_from_tensors(
self, tensors_0_1: torch.Tensor, drop=False,
is_training=False
is_training=False,
has_been_preprocessed=False
) -> torch.Tensor:
with torch.no_grad():
# tensors should be 0-1
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
if not has_been_preprocessed:
# tensors should be 0-1
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
if drop:
clip_image = clip_image * 0
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
# unconditional
if drop:
if self.clip_noise_zero:
tensors_0_1 = torch.rand_like(tensors_0_1).detach()
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
dtype=get_torch_dtype(self.sd_ref().dtype))
tensors_0_1 = tensors_0_1 * noise_scale
else:
tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
# tensors_0_1 = tensors_0_1 * 0
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
else:
if drop:
# scale the noise down
if self.clip_noise_zero:
tensors_0_1 = torch.rand_like(tensors_0_1).detach()
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
dtype=get_torch_dtype(self.sd_ref().dtype))
tensors_0_1 = tensors_0_1 * noise_scale
else:
tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
# tensors_0_1 = tensors_0_1 * 0
mean = torch.tensor(self.clip_image_processor.image_mean).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
std = torch.tensor(self.clip_image_processor.image_std).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
else:
self.image_encoder.eval()
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
clip_image_embeds = clip_output.hidden_states[-2]
if self.config.clip_layer == 'penultimate_hidden_states':
# they skip last layer for ip+
# https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
clip_image_embeds = clip_output.hidden_states[-2]
elif self.config.clip_layer == 'last_hidden_state':
clip_image_embeds = clip_output.hidden_states[-1]
else:
clip_image_embeds = clip_output.image_embeds
return clip_image_embeds
import torch
@@ -236,6 +277,9 @@ class ClipVisionAdapter(torch.nn.Module):
# adds it to the tokenizer
def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
if clip_image_embeds.ndim == 2:
# expand the token dimension
clip_image_embeds = clip_image_embeds.unsqueeze(1)
image_prompt_embeds = self.embedder(clip_image_embeds)
# todo add support for multiple batch sizes
if image_prompt_embeds.shape[0] != 1: