mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-02-07 22:19:57 +00:00)
WIP on clip vision encoder
@@ -740,14 +740,36 @@ class SDTrainer(BaseSDTrainProcess):
embeds_to_use = conditional_embeds.clone().detach()
# handle clip vision adapter by removing triggers from prompt and replacing with the class name
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
    prompt_list = batch.get_caption_list()
    class_name = ''

    triggers = ['[trigger]', '[name]']
    remove_tokens = []

    if self.embed_config is not None:
        triggers.append(self.embed_config.trigger)
        for i in range(1, self.embed_config.tokens):
            remove_tokens.append(f"{self.embed_config.trigger}_{i}")
        if self.embed_config.trigger_class_name is not None:
            class_name = self.embed_config.trigger_class_name

    if self.adapter is not None:
        triggers.append(self.adapter_config.trigger)
        for i in range(1, self.adapter_config.num_tokens):
            remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
        if self.adapter_config.trigger_class_name is not None:
            class_name = self.adapter_config.trigger_class_name

    for idx, prompt in enumerate(prompt_list):
        prompt = self.adapter.inject_trigger_class_name_into_prompt(prompt)
        for remove_token in remove_tokens:
            prompt = prompt.replace(remove_token, '')
        for trigger in triggers:
            prompt = prompt.replace(trigger, class_name)
        prompt_list[idx] = prompt

    embeds_to_use = self.sd.encode_prompt(
        prompt,
        prompt_list,
        long_prompts=self.do_long_prompts).to(
        self.device_torch,
        dtype=dtype).detach()
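As a standalone illustration of the prompt rewrite this hunk performs, the sketch below reproduces the replacement loop outside the trainer. The helper name and the sample prompts are hypothetical; the trigger strings mirror defaults that appear elsewhere in this diff.

# Hypothetical standalone sketch of the replacement loop above (not part of the repo).
def strip_triggers(prompts, triggers, remove_tokens, class_name):
    cleaned = []
    for prompt in prompts:
        for token in remove_tokens:
            prompt = prompt.replace(token, '')            # drop the extra embedding tokens, e.g. "tri993r_1"
        for trigger in triggers:
            prompt = prompt.replace(trigger, class_name)  # swap the trigger word for the class name
        cleaned.append(prompt)
    return cleaned


prompts = ["a photo of [trigger] riding a bike", "tri993r tri993r_1 portrait"]
print(strip_triggers(prompts, ['[trigger]', 'tri993r'], ['tri993r_1'], 'person'))
# -> ['a photo of person riding a bike', 'person  portrait']

The ordering matters the same way it does in the hunk: the numbered "_i" tokens are stripped before the base trigger is replaced, otherwise "tri993r_1" would become "person_1" instead of disappearing.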
@@ -1030,7 +1052,8 @@ class SDTrainer(BaseSDTrainProcess):
if has_clip_image:
    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
        clip_images.detach().to(self.device_torch, dtype=dtype),
        is_training=True
        is_training=True,
        has_been_preprocessed=True
    )
else:
    # just do a blank one
@@ -1039,7 +1062,9 @@ class SDTrainer(BaseSDTrainProcess):
        (noisy_latents.shape[0], 3, 512, 512),
        device=self.device_torch, dtype=dtype
    ),
    is_training=True
    is_training=True,
    has_been_preprocessed=True,
    drop=True
)
# it will be injected into the tokenizer when called
self.adapter(conditional_clip_embeds)

@@ -4,6 +4,8 @@ import torch
import weakref

from toolkit.config_modules import AdapterConfig
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule
from toolkit.prompt_utils import PromptEmbeds
from toolkit.train_tools import get_torch_dtype

@@ -24,11 +26,11 @@ import torch.nn as nn
class Embedder(nn.Module):
    def __init__(
            self,
            num_input_tokens: int = 50,
            num_input_tokens: int = 1,
            input_dim: int = 1024,
            num_output_tokens: int = 8,
            output_dim: int = 768,
            mid_dim: int = 128,
            mid_dim: int = 1024
    ):
        super(Embedder, self).__init__()
        self.num_output_tokens = num_output_tokens
@@ -36,29 +38,24 @@ class Embedder(nn.Module):
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Convolutional layer to reduce channel dimension
        self.conv = nn.Conv1d(in_channels=input_dim, out_channels=mid_dim, kernel_size=1)

        # GELU Activation
        self.layer_norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)

        # Layer Normalization
        self.layer_norm = nn.LayerNorm(mid_dim)

        # Adaptive pooling to change sequence length
        self.adaptive_pool = nn.AdaptiveAvgPool1d(num_output_tokens)

        # Linear layer for final transformation
        self.final_linear = nn.Linear(mid_dim, output_dim)
        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))

    def forward(self, x):
        x = x.permute(0, 2, 1)  # Adjust for Conv1d
        x = self.conv(x)
        x = self.layer_norm(x)
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1)  # Apply LayerNorm
        x = self.adaptive_pool(x)
        x = x.permute(0, 2, 1)  # Adjust back
        x = self.final_linear(x)
        x = self.fc2(x)
        x = x.view(-1, self.num_output_tokens, self.output_dim)

        # repeat the static tokens for each batch
        static_tokens = torch.stack([self.static_tokens] * x.shape[0])
        x = static_tokens + x

        return x

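Read on their own, the surviving lines suggest the Embedder collapses to a small MLP plus learned static tokens (the Conv1d / adaptive-pool / final_linear path being the older version). Below is a rough, self-contained approximation under that reading; the class name EmbedderSketch and the demo shapes are mine, not the repo's.

import torch
import torch.nn as nn

# Approximate reconstruction of the reworked Embedder: LayerNorm -> Linear -> GELU ->
# Linear, reshaped into num_output_tokens tokens and offset by learned static tokens.
# Defaults mirror the hunk above (input_dim=1024, 8 output tokens of width 768, mid_dim=1024).
class EmbedderSketch(nn.Module):
    def __init__(self, input_dim=1024, num_output_tokens=8, output_dim=768, mid_dim=1024):
        super().__init__()
        self.num_output_tokens = num_output_tokens
        self.output_dim = output_dim
        self.layer_norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))

    def forward(self, x):  # x: (batch, num_input_tokens, input_dim)
        x = self.layer_norm(x)
        x = self.fc2(self.gelu(self.fc1(x)))
        x = x.view(-1, self.num_output_tokens, self.output_dim)
        # add the learned static tokens to every item in the batch (broadcast over batch dim)
        return self.static_tokens.unsqueeze(0) + x


embeds = EmbedderSketch()(torch.randn(2, 1, 1024))  # e.g. a single pooled CLIP embedding per image
print(embeds.shape)  # torch.Size([2, 8, 768])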
@@ -140,24 +137,29 @@ class ClipVisionAdapter(torch.nn.Module):
    self.image_encoder.train()
else:
    self.image_encoder.eval()
# self.embedder = Embedder(
# num_output_tokens=self.config.num_tokens,
# num_input_tokens=self.image_encoder.config.top_k, # max_position_embeddings ?
# input_dim=self.image_encoder.config.hidden_size,
# output_dim=sd.unet.config['cross_attention_dim'],
# ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
heads = 12 if not sd.is_xl else 20
# dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
dim = sd.unet.config['cross_attention_dim']
self.embedder = Resampler(
    dim=dim,
    depth=4,
    dim_head=64,
    heads=heads,
    num_queries=self.config.num_tokens,  # usually 16
    embedding_dim=self.image_encoder.config.hidden_size,
    output_dim=sd.unet.config['cross_attention_dim'],
    ff_mult=4

# max_seq_len = CLIP tokens + CLS token
image_encoder_state_dict = self.image_encoder.state_dict()
in_tokens = 257
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
    # clip
    in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

if hasattr(self.image_encoder.config, 'hidden_sizes'):
    embedding_dim = self.image_encoder.config.hidden_sizes[-1]
else:
    embedding_dim = self.image_encoder.config.hidden_size

if self.config.clip_layer == 'image_embeds':
    in_tokens = 1
    embedding_dim = self.image_encoder.config.projection_dim

self.embedder = Embedder(
    num_output_tokens=self.config.num_tokens,
    num_input_tokens=in_tokens,
    input_dim=embedding_dim,
    output_dim=self.sd_ref().unet.config['cross_attention_dim'],
    mid_dim=embedding_dim * self.config.num_tokens,
).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))

self.embedder.train()
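To sanity-check the in_tokens / embedding_dim detection above against a real CLIP vision tower, a quick probe could look like the following; the checkpoint name is only an example, and any CLIPVisionModelWithProjection checkpoint exposes the same keys.

from transformers import CLIPVisionModelWithProjection

# Example checkpoint: ViT-L/14 at 224px -> 256 patch tokens + 1 CLS token = 257 positions.
encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
state = encoder.state_dict()
in_tokens = int(state["vision_model.embeddings.position_embedding.weight"].shape[0])
print(in_tokens)                      # 257, the default the hunk falls back to
print(encoder.config.hidden_size)     # 1024 -> embedding_dim when a hidden-state layer is used
print(encoder.config.projection_dim)  # 768  -> embedding_dim when clip_layer == 'image_embeds'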
@@ -186,37 +188,76 @@ class ClipVisionAdapter(torch.nn.Module):

def get_clip_image_embeds_from_tensors(
        self, tensors_0_1: torch.Tensor, drop=False,
        is_training=False
        is_training=False,
        has_been_preprocessed=False
) -> torch.Tensor:
    with torch.no_grad():
        # tensors should be 0-1
        # todo: add support for sdxl
        if tensors_0_1.ndim == 3:
            tensors_0_1 = tensors_0_1.unsqueeze(0)
        # training tensors are 0 - 1
        tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
        # if images are out of this range throw error
        if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
            raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                tensors_0_1.min(), tensors_0_1.max()
            ))
        if not has_been_preprocessed:
            # tensors should be 0-1
            if tensors_0_1.ndim == 3:
                tensors_0_1 = tensors_0_1.unsqueeze(0)
            # training tensors are 0 - 1
            tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)

            clip_image = self.clip_image_processor(
                images=tensors_0_1,
                return_tensors="pt",
                do_resize=True,
                do_rescale=False,
            ).pixel_values
            clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
            if drop:
                clip_image = clip_image * 0
            # if images are out of this range throw error
            if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                    tensors_0_1.min(), tensors_0_1.max()
                ))
            # unconditional
            if drop:
                if self.clip_noise_zero:
                    tensors_0_1 = torch.rand_like(tensors_0_1).detach()
                    noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                             dtype=get_torch_dtype(self.sd_ref().dtype))
                    tensors_0_1 = tensors_0_1 * noise_scale
                else:
                    tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
                # tensors_0_1 = tensors_0_1 * 0
            clip_image = self.clip_image_processor(
                images=tensors_0_1,
                return_tensors="pt",
                do_resize=True,
                do_rescale=False,
            ).pixel_values
        else:
            if drop:
                # scale the noise down
                if self.clip_noise_zero:
                    tensors_0_1 = torch.rand_like(tensors_0_1).detach()
                    noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                             dtype=get_torch_dtype(self.sd_ref().dtype))
                    tensors_0_1 = tensors_0_1 * noise_scale
                else:
                    tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
                # tensors_0_1 = tensors_0_1 * 0
                mean = torch.tensor(self.clip_image_processor.image_mean).to(
                    self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
                ).detach()
                std = torch.tensor(self.clip_image_processor.image_std).to(
                    self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
                ).detach()
                tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
                clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
            else:
                clip_image = tensors_0_1
        clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
    with torch.set_grad_enabled(is_training):
        if is_training:
            self.image_encoder.train()
        else:
            self.image_encoder.eval()
        clip_output = self.image_encoder(clip_image, output_hidden_states=True)
        clip_image_embeds = clip_output.hidden_states[-2]

        if self.config.clip_layer == 'penultimate_hidden_states':
            # they skip last layer for ip+
            # https://github.com/tencent-ailab/IP-Adapter/blob/f4b6742db35ea6d81c7b829a55b0a312c7f5a677/tutorial_train_plus.py#L403C26-L403C26
            clip_image_embeds = clip_output.hidden_states[-2]
        elif self.config.clip_layer == 'last_hidden_state':
            clip_image_embeds = clip_output.hidden_states[-1]
        else:
            clip_image_embeds = clip_output.image_embeds
    return clip_image_embeds
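In the has_been_preprocessed branch, dropped (unconditional) tensors are regenerated as raw 0-1 noise and therefore normalized by hand instead of going through CLIPImageProcessor again. A minimal standalone version of that arithmetic, assuming the processor reports the standard OpenAI CLIP mean/std, is:

import torch

# Standard OpenAI CLIP normalization statistics (an assumption; the real code reads
# them from self.clip_image_processor.image_mean / image_std).
clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])

def normalize_preprocessed(tensors_0_1: torch.Tensor) -> torch.Tensor:
    # quantize to 1/255 steps, matching what a uint8 round-trip would produce
    tensors_0_1 = torch.clip(255.0 * tensors_0_1, 0, 255).round() / 255.0
    return (tensors_0_1 - clip_mean.view(1, 3, 1, 1)) / clip_std.view(1, 3, 1, 1)

batch = torch.rand(2, 3, 224, 224)  # already resized to the CLIP input resolution
print(normalize_preprocessed(batch).shape)  # torch.Size([2, 3, 224, 224])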
@@ -236,6 +277,9 @@ class ClipVisionAdapter(torch.nn.Module):
# adds it to the tokenizer
def forward(self, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
    clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
    if clip_image_embeds.ndim == 2:
        # expand the token dimension
        clip_image_embeds = clip_image_embeds.unsqueeze(1)
    image_prompt_embeds = self.embedder(clip_image_embeds)
    # todo add support for multiple batch sizes
    if image_prompt_embeds.shape[0] != 1:
@@ -166,7 +166,7 @@ class AdapterConfig:

# clip vision
self.trigger = kwargs.get('trigger', 'tri993r')
self.trigger_class_name = kwargs.get('trigger_class_name', 'person')
self.trigger_class_name = kwargs.get('trigger_class_name', None)

self.class_names = kwargs.get('class_names', [])
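One consequence of the default moving from 'person' to None: trigger-to-class-name substitution now only happens when the config sets trigger_class_name explicitly. A hypothetical usage, assuming AdapterConfig accepts plain keyword arguments as the kwargs.get() calls suggest:

# Hypothetical; key names come from the kwargs.get() calls above.
adapter_config = AdapterConfig(
    trigger='tri993r',
    trigger_class_name='person',  # previously the implicit default, now must be set explicitly
)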
@@ -188,6 +188,7 @@ class EmbeddingConfig:
self.tokens = kwargs.get('tokens', 4)
self.init_words = kwargs.get('init_words', '*')
self.save_format = kwargs.get('save_format', 'safetensors')
self.trigger_class_name = kwargs.get('trigger_class_name', None)  # used for inverted masked prior


ContentOrStyleType = Literal['balanced', 'style', 'content']
@@ -39,7 +39,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
    StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
    StableDiffusionXLImg2ImgPipeline, LCMScheduler
    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel
import diffusers
from diffusers import \
    AutoencoderKL, \
@@ -242,10 +242,21 @@ class StableDiffusion:
    device_map="auto",
    torch_dtype=self.torch_dtype,
)

# load the transformer
subfolder = "transformer"
# check if it is just the unet
if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
    subfolder = None
# load the transformer only from the save
transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype, subfolder=subfolder)

# replace the to function with a no-op since it throws an error instead of a warning
text_encoder.to = lambda *args, **kwargs: None
pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
    model_path,
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    transformer=transformer,
    text_encoder=text_encoder,
    dtype=dtype,
    device=self.device_torch,
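The subfolder fallback above distinguishes a full diffusers save, which keeps its weights in a transformer/ subfolder, from a transformer-only save. In isolation the pattern is roughly the following; model_path is illustrative and would point at a prior save from this toolkit.

import os

import torch
from diffusers import Transformer2DModel

model_path = "/path/to/saved_model"  # illustrative local directory
subfolder = "transformer"
if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
    # the directory itself already holds the transformer weights
    subfolder = None
transformer = Transformer2DModel.from_pretrained(
    model_path, subfolder=subfolder, torch_dtype=torch.float16
)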
@@ -1081,10 +1092,14 @@ class StableDiffusion:
    else:
        noise_pred = noise_pred
else:
    if self.unet.device != self.device_torch:
        self.unet.to(self.device_torch)
    if self.unet.dtype != self.torch_dtype:
        self.unet = self.unet.to(dtype=self.torch_dtype)
    noise_pred = self.unet(
        latent_model_input.to(self.device_torch, self.torch_dtype),
        timestep,
        encoder_hidden_states=text_embeddings.text_embeds,
        encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
        **kwargs,
    ).sample

@@ -1485,10 +1500,16 @@ class StableDiffusion:
# saving in diffusers format
if not output_file.endswith('.safetensors'):
    # diffusers
    self.pipeline.save_pretrained(
        save_directory=output_file,
        safe_serialization=True,
    )
    if self.is_pixart:
        self.unet.save_pretrained(
            save_directory=output_file,
            safe_serialization=True,
        )
    else:
        self.pipeline.save_pretrained(
            save_directory=output_file,
            safe_serialization=True,
        )
    # save out meta config
    meta_path = os.path.join(output_file, 'aitk_meta.yaml')
    with open(meta_path, 'w') as f: