Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 11:11:37 +00:00)

Commit: WIP on clip vision encoder
@@ -4,6 +4,8 @@ import torch
import weakref

from toolkit.config_modules import AdapterConfig
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype

@@ -24,11 +26,11 @@ import torch.nn as nn
class Embedder(nn.Module):
    def __init__(
            self,
            num_input_tokens: int = 50,
            num_input_tokens: int = 1,
            input_dim: int = 1024,
            num_output_tokens: int = 8,
            output_dim: int = 768,
            mid_dim: int = 128,
            mid_dim: int = 1024
    ):
        super(Embedder, self).__init__()
        self.num_output_tokens = num_output_tokens
@@ -36,29 +38,24 @@ class Embedder(nn.Module):
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Convolutional layer to reduce channel dimension
        self.conv = nn.Conv1d(in_channels=input_dim, out_channels=mid_dim, kernel_size=1)

        # GELU Activation
        self.layer_norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)

        # Layer Normalization
        self.layer_norm = nn.LayerNorm(mid_dim)

        # Adaptive pooling to change sequence length
        self.adaptive_pool = nn.AdaptiveAvgPool1d(num_output_tokens)

        # Linear layer for final transformation
        self.final_linear = nn.Linear(mid_dim, output_dim)
        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))

    def forward(self, x):
        x = x.permute(0, 2, 1)  # Adjust for Conv1d
        x = self.conv(x)
        x = self.layer_norm(x)
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1)  # Apply LayerNorm
        x = self.adaptive_pool(x)
        x = x.permute(0, 2, 1)  # Adjust back
        x = self.final_linear(x)
        x = self.fc2(x)
        x = x.view(-1, self.num_output_tokens, self.output_dim)

        # repeat the static tokens for each batch
        static_tokens = torch.stack([self.static_tokens] * x.shape[0])
        x = static_tokens + x

        return x

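As an aside, here is a minimal standalone sketch of the data flow the reworked Embedder appears to implement (layer norm, two linear layers with a GELU in between, reshape, then addition of learned static tokens). The sizes below are illustrative assumptions, not values taken from this commit:

import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the adapter config.
batch, num_input_tokens = 4, 1
input_dim, mid_dim, output_dim, num_output_tokens = 1024, 1024, 768, 8

layer_norm = nn.LayerNorm(input_dim)
fc1 = nn.Linear(input_dim, mid_dim)
gelu = nn.GELU()
fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))

x = torch.randn(batch, num_input_tokens, input_dim)   # e.g. a pooled CLIP image embedding per sample
x = fc2(gelu(fc1(layer_norm(x))))                     # (batch, 1, output_dim * num_output_tokens)
x = x.view(-1, num_output_tokens, output_dim)         # (batch, num_output_tokens, output_dim)
x = torch.stack([static_tokens] * x.shape[0]) + x     # repeat the learned static tokens per batch item
print(x.shape)                                        # torch.Size([4, 8, 768])

With num_input_tokens=1 (the new default), the reshape keeps the batch dimension intact and each image yields num_output_tokens conditioning vectors of width output_dim.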
@@ -140,24 +137,29 @@ class ClipVisionAdapter(torch.nn.Module):
            self.image_encoder.train()
        else:
            self.image_encoder.eval()
        # self.embedder = Embedder(
        #     num_output_tokens=self.config.num_tokens,
        #     num_input_tokens=self.image_encoder.config.top_k,  # max_position_embeddings ?
        #     input_dim=self.image_encoder.config.hidden_size,
        #     output_dim=sd.unet.config['cross_attention_dim'],
        # ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        heads = 12 if not sd.is_xl else 20
        # dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
        dim = sd.unet.config['cross_attention_dim']
        self.embedder = Resampler(
            dim=dim,
            depth=4,
            dim_head=64,
            heads=heads,
            num_queries=self.config.num_tokens,  # usually 16
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=sd.unet.config['cross_attention_dim'],
            ff_mult=4

        # max_seq_len = CLIP tokens + CLS token
        image_encoder_state_dict = self.image_encoder.state_dict()
        in_tokens = 257
        if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
            # clip
            in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

        if hasattr(self.image_encoder.config, 'hidden_sizes'):
            embedding_dim = self.image_encoder.config.hidden_sizes[-1]
        else:
            embedding_dim = self.image_encoder.config.hidden_size

        if self.config.clip_layer == 'image_embeds':
            in_tokens = 1
            embedding_dim = self.image_encoder.config.projection_dim

        self.embedder = Embedder(
            num_output_tokens=self.config.num_tokens,
            num_input_tokens=in_tokens,
            input_dim=embedding_dim,
            output_dim=self.sd_ref().unet.config['cross_attention_dim'],
            mid_dim=embedding_dim * self.config.num_tokens,
        ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))

        self.embedder.train()
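For context, a hedged sketch of how the token count and embedding width that feed the new Embedder can be read off a CLIP vision encoder, mirroring the probing logic in this hunk. The checkpoint name and printed numbers are illustrative assumptions, not part of the commit:

from transformers import CLIPVisionModelWithProjection

# Assumed checkpoint for illustration; ai-toolkit configures its own encoder.
encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
state_dict = encoder.state_dict()

in_tokens = 257  # fallback: 16x16 patch tokens + 1 CLS token for a 224px ViT-L/14
key = "vision_model.embeddings.position_embedding.weight"
if key in state_dict:
    in_tokens = int(state_dict[key].shape[0])

# ConvNeXt-style encoders expose hidden_sizes; ViT-style CLIP exposes hidden_size.
if hasattr(encoder.config, "hidden_sizes"):
    embedding_dim = encoder.config.hidden_sizes[-1]
else:
    embedding_dim = encoder.config.hidden_size

# When the pooled image_embeds output is used instead of hidden states:
# in_tokens = 1 and embedding_dim = encoder.config.projection_dim
print(in_tokens, embedding_dim)  # 257 1024 for ViT-L/14 at 224px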
@@ -186,37 +188,76 @@ class ClipVisionAdapter(torch.nn.Module):

    def get_clip_image_embeds_from_tensors(
            self, tensors_0_1: torch.Tensor, drop=False,
            is_training=False
            is_training=False,
            has_been_preprocessed=False
    ) -> torch.Tensor:
        with torch.no_grad():
            # tensors should be 0-1
            # todo: add support for sdxl
            if tensors_0_1.ndim == 3:
                tensors_0_1 = tensors_0_1.unsqueeze(0)
            # training tensors are 0 - 1
            tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
            # if images are out of this range throw error
            if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                    tensors_0_1.min(), tensors_0_1.max()
                ))
            if not has_been_preprocessed:
                # tensors should be 0-1
                if tensors_0_1.ndim == 3:
                    tensors_0_1 = tensors_0_1.unsqueeze(0)
                # training tensors are 0 - 1
                tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)

            clip_image = self.clip_image_processor(
                images=tensors_0_1,
                return_tensors="pt",
                do_resize=True,
                do_rescale=False,
            ).pixel_values
            clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
            if drop:
                clip_image = clip_image * 0
                # if images are out of this range throw error
                if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                    raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                        tensors_0_1.min(), tensors_0_1.max()
                    ))
                # unconditional
                if drop:
                    if self.clip_noise_zero:
                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
                        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                                 dtype=get_torch_dtype(self.sd_ref().dtype))
                        tensors_0_1 = tensors_0_1 * noise_scale
                    else:
                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
                    # tensors_0_1 = tensors_0_1 * 0
                clip_image = self.clip_image_processor(
                    images=tensors_0_1,
                    return_tensors="pt",
                    do_resize=True,
                    do_rescale=False,
                ).pixel_values
            else:
                if drop:
                    # scale the noise down
                    if self.clip_noise_zero:
                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
                        noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                                 dtype=get_torch_dtype(self.sd_ref().dtype))
                        tensors_0_1 = tensors_0_1 * noise_scale
                    else:
                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
                    # tensors_0_1 = tensors_0_1 * 0
                mean = torch.tensor(self.clip_image_processor.image_mean).to(
                    self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
                ).detach()
                std = torch.tensor(self.clip_image_processor.image_std).to(
                    self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
                ).detach()
                tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
                clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])

            else:
                clip_image = tensors_0_1
            clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
        with torch.set_grad_enabled(is_training):
            if is_training:
                self.image_encoder.train()
            else:
                self.image_encoder.eval()
            clip_output = self.image_encoder(clip_image, output_hidden_states=True)
            clip_image_embeds = clip_output.hidden_states[-2]

            if self.config.clip_layer == 'penultimate_hidden_states':
                # they skip last layer for ip+
                # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
                clip_image_embeds = clip_output.hidden_states[-2]
            elif self.config.clip_layer == 'last_hidden_state':
                clip_image_embeds = clip_output.hidden_states[-1]
            else:
                clip_image_embeds = clip_output.image_embeds
        return clip_image_embeds


import torch
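The two preprocessing branches added above (run the CLIP image processor on raw 0-1 tensors, or normalize an already-resized batch by hand) can be compared in isolation. A rough sketch, assuming a default CLIPImageProcessor and a toy batch; nothing here comes from the commit itself:

import torch
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor()            # default CLIP mean/std, 224px target
images_0_1 = torch.rand(2, 3, 224, 224)     # toy batch already scaled to [0, 1]

# Branch 1: not yet preprocessed -> let the processor resize and normalize.
pixel_values = processor(
    images=images_0_1,
    return_tensors="pt",
    do_resize=True,
    do_rescale=False,                       # inputs are already 0-1
).pixel_values

# Branch 2: already preprocessed/resized -> normalize manually with the same stats.
mean = torch.tensor(processor.image_mean).view(1, 3, 1, 1)
std = torch.tensor(processor.image_std).view(1, 3, 1, 1)
quantized = torch.clip(images_0_1 * 255., 0, 255).round() / 255.0   # mimic 8-bit rounding
pixel_values_manual = (quantized - mean) / std

print(pixel_values.shape, pixel_values_manual.shape)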
@@ -236,6 +277,9 @@ class ClipVisionAdapter(torch.nn.Module):
    # adds it to the tokenizer
    def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
        clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        if clip_image_embeds.ndim == 2:
            # expand the token dimension
            clip_image_embeds = clip_image_embeds.unsqueeze(1)
        image_prompt_embeds = self.embedder(clip_image_embeds)
        # todo add support for multiple batch sizes
        if image_prompt_embeds.shape[0] != 1: