Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Commit: Work on additional image embedding methods. Finalized the zipper resampler. It works amazingly well.
@@ -998,7 +998,7 @@ class SDTrainer(BaseSDTrainProcess):
                    tensors_0_1=clip_images,
                    is_training=True,
                    has_been_preprocessed=True,
                    quad_count=quad_count
                    quad_count=quad_count,
                )

            with self.timer('encode_prompt'):
@@ -557,6 +557,21 @@ class CustomAdapter(torch.nn.Module):
            quad_count=4,
    ) -> PromptEmbeds:
        if self.adapter_type == 'ilora':
            if tensors_0_1 is None:
                # scale the noise down
                tensors_0_1 = torch.rand([1, 3, self.input_size, self.input_size], device=self.device)
                noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
                                         dtype=get_torch_dtype(self.sd_ref().dtype))
                tensors_0_1 = tensors_0_1 * noise_scale
                # tensors_0_1 = tensors_0_1 * 0
            mean = torch.tensor(self.clip_image_processor.image_mean).to(
                self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
            ).detach()
            std = torch.tensor(self.clip_image_processor.image_std).to(
                self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
            ).detach()
            tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
            clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
        with torch.no_grad():
            # on training the clip image is created in the dataloader
            if not has_been_preprocessed:
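For context (illustrative, not part of the commit): the normalization above reproduces standard CLIP preprocessing by hand. A minimal standalone sketch, assuming the usual CLIP mean/std defaults in place of self.clip_image_processor.image_mean / image_std:

import torch

# assumed CLIP defaults; the adapter reads these from its clip_image_processor
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, 3, 1, 1)
std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, 3, 1, 1)

tensors_0_1 = torch.rand(1, 3, 224, 224)                               # stand-in for a [0, 1] image batch
tensors_0_1 = torch.clip(255. * tensors_0_1, 0, 255).round() / 255.0   # quantize to 8-bit levels
clip_image = (tensors_0_1 - mean) / std                                # what the image encoder receives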
@@ -626,7 +641,7 @@ class CustomAdapter(torch.nn.Module):
            if not is_training or not self.config.train_image_encoder:
                img_embeds = img_embeds.detach()

            self.ilora_module.img_embeds = img_embeds
            self.ilora_module(img_embeds)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        if self.config.type == 'photo_maker':
@@ -9,6 +9,7 @@ from torch.nn.modules.module import T
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.zipper_resampler import ZipperResampler
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
@@ -33,6 +34,7 @@ from transformers import (
    CLIPVisionModel,
    AutoImageProcessor,
    ConvNextModel,
    ConvNextV2ForImageClassification,
    ConvNextForImageClassification,
    ConvNextImageProcessor
)
@@ -226,6 +228,20 @@ class IPAdapter(torch.nn.Module):
                adapter_config.image_encoder_path,
                use_safetensors=True,
            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        elif self.config.image_encoder_arch == 'convnextv2':
            try:
                self.clip_image_processor = AutoImageProcessor.from_pretrained(adapter_config.image_encoder_path)
            except EnvironmentError:
                print(f"could not load image processor from {adapter_config.image_encoder_path}")
                self.clip_image_processor = ConvNextImageProcessor(
                    size=512,
                    image_mean=[0.485, 0.456, 0.406],
                    image_std=[0.229, 0.224, 0.225],
                )
            self.image_encoder = ConvNextV2ForImageClassification.from_pretrained(
                adapter_config.image_encoder_path,
                use_safetensors=True,
            ).to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        elif self.config.image_encoder_arch == 'vit-hybrid':
            try:
                self.clip_image_processor = ViTHybridImageProcessor.from_pretrained(adapter_config.image_encoder_path)
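Illustrative sketch (not part of the commit) of what the new 'convnextv2' branch loads; the checkpoint id below is a hypothetical example, in the adapter it comes from adapter_config.image_encoder_path:

from transformers import AutoImageProcessor, ConvNextV2ForImageClassification

path = "facebook/convnextv2-tiny-22k-384"   # hypothetical example id, or a local path
processor = AutoImageProcessor.from_pretrained(path)
encoder = ConvNextV2ForImageClassification.from_pretrained(path)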
@@ -275,8 +291,12 @@ class IPAdapter(torch.nn.Module):
            )
        if 'height' in self.clip_image_processor.size:
            self.input_size = self.clip_image_processor.size['height']
        else:
        elif hasattr(self.clip_image_processor, 'crop_size'):
            self.input_size = self.clip_image_processor.crop_size['height']
        elif 'shortest_edge' in self.clip_image_processor.size.keys():
            self.input_size = self.clip_image_processor.size['shortest_edge']
        else:
            raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
        self.current_scale = 1.0
        self.is_active = True
        if adapter_config.type == 'ip':
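The extra crop_size branch above exists because recent transformers CLIP processors report their resolution there rather than under size['height']. An illustrative check (not part of the commit):

from transformers import CLIPImageProcessor

proc = CLIPImageProcessor()
print(proc.size)        # e.g. {'shortest_edge': 224} -- no 'height' key
print(proc.crop_size)   # e.g. {'height': 224, 'width': 224} -- used as input_size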
@@ -311,6 +331,39 @@ class IPAdapter(torch.nn.Module):
                output_dim=sd.unet.config['cross_attention_dim'],
                ff_mult=4
            )
        elif adapter_config.type == 'ipz':
            dim = sd.unet.config['cross_attention_dim']
            if hasattr(self.image_encoder.config, 'hidden_sizes'):
                embedding_dim = self.image_encoder.config.hidden_sizes[-1]
            else:
                embedding_dim = self.image_encoder.config.hidden_size

            image_encoder_state_dict = self.image_encoder.state_dict()
            # max_seq_len = CLIP tokens + CLS token
            in_tokens = 257
            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
                # clip
                in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

            if self.config.image_encoder_arch.startswith('convnext'):
                in_tokens = 16 * 16
                embedding_dim = self.image_encoder.config.hidden_sizes[-1]

            is_conv_next = self.config.image_encoder_arch.startswith('convnext')

            out_tokens = self.config.num_tokens if self.config.num_tokens > 0 else in_tokens
            # ip-adapter-plus
            image_proj_model = ZipperResampler(
                in_size=embedding_dim,
                in_tokens=in_tokens,
                out_size=dim,
                out_tokens=out_tokens,
                hidden_size=embedding_dim,
                hidden_tokens=in_tokens,
                # num_blocks=1 if not is_conv_next else 2,
                num_blocks=1 if not is_conv_next else 2,
                is_conv_input=is_conv_next
            )
        elif adapter_config.type == 'ilora':
            # we apply the clip encodings to the LoRA
            image_proj_model = None
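As a concrete illustration of the wiring above (hypothetical numbers, not part of the commit): a CLIP ViT-H/14 vision tower (hidden width 1280, 257 tokens) feeding an SDXL-style UNet (cross_attention_dim 2048), with num_tokens left at 0:

from toolkit.models.zipper_resampler import ZipperResampler

embedding_dim = 1280     # image_encoder.config.hidden_size
in_tokens = 257          # 16*16 patches + CLS, read from the position embedding
dim = 2048               # sd.unet.config['cross_attention_dim']
out_tokens = in_tokens   # num_tokens == 0 keeps the token count

image_proj_model = ZipperResampler(
    in_size=embedding_dim, in_tokens=in_tokens,
    out_size=dim, out_tokens=out_tokens,
    hidden_size=embedding_dim, hidden_tokens=in_tokens,
    num_blocks=1, is_conv_input=False,
)
# (B, 257, 1280) encoder hidden states -> (B, 257, 2048) cross-attention tokens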
@@ -1,6 +1,8 @@
import torch
import torch.nn as nn

from toolkit.models.zipper_resampler import ContextualAlphaMask


# Conv1d MLP
# MLP that can alternately be used as a conv1d on dim 1
@@ -86,46 +88,7 @@ class ZipperBlock(nn.Module):
        return x


class ContextualAlphaMask(nn.Module):
    def __init__(
            self,
            dim: int = 768,
    ):
        super(ContextualAlphaMask, self).__init__()
        self.dim = dim

        half_dim = dim // 2
        quarter_dim = dim // 4

        self.fc1 = nn.Linear(self.dim, self.dim)
        self.fc2 = nn.Linear(self.dim, half_dim)
        self.norm1 = nn.LayerNorm(half_dim)
        self.fc3 = nn.Linear(half_dim, half_dim)
        self.fc4 = nn.Linear(half_dim, quarter_dim)
        self.norm2 = nn.LayerNorm(quarter_dim)
        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
        self.fc6 = nn.Linear(quarter_dim, 1)
        # set fc6 weights to near zero
        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
        self.act_fn = nn.GELU()

    def forward(self, x):
        # x = (batch_size, 77, 768)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        x = self.norm1(x)
        x = self.act_fn(x)
        x = self.fc3(x)
        x = self.act_fn(x)
        x = self.fc4(x)
        x = self.norm2(x)
        x = self.act_fn(x)
        x = self.fc5(x)
        x = self.act_fn(x)
        x = self.fc6(x)
        x = torch.sigmoid(x)
        return x
@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
from typing import TYPE_CHECKING
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler

if TYPE_CHECKING:
    from toolkit.lora_special import LoRAModule
@@ -26,7 +27,7 @@ class InstantLoRAMidModule(torch.nn.Module):
        self.lora_module_ref = weakref.ref(lora_module)
        self.instant_lora_module_ref = weakref.ref(instant_lora_module)

        self.zip = ZipperBlock(
        self.zip = ZipperModule(
            in_size=self.vision_hidden_size,
            in_tokens=self.vision_tokens,
            out_size=self.dim,
@@ -71,7 +72,7 @@ class InstantLoRAModule(torch.nn.Module):
            sd: 'StableDiffusion'
    ):
        super(InstantLoRAModule, self).__init__()
        self.linear = torch.nn.Linear(2, 1)
        # self.linear = torch.nn.Linear(2, 1)
        self.sd_ref = weakref.ref(sd)
        self.dim = sd.network.lora_dim
        self.vision_hidden_size = vision_hidden_size
@@ -83,6 +84,15 @@ class InstantLoRAModule(torch.nn.Module):
        # disable merging in. It is slower on inference
        self.sd_ref().network.can_merge_in = False

        self.resampler = ZipperResampler(
            in_size=self.vision_hidden_size,
            in_tokens=self.vision_tokens,
            out_size=self.vision_hidden_size,
            out_tokens=self.vision_tokens,
            hidden_size=self.vision_hidden_size,
            hidden_tokens=self.vision_tokens
        )

        self.ilora_modules = torch.nn.ModuleList()

        lora_modules = self.sd_ref().network.get_all_modules()
@@ -99,5 +109,7 @@ class InstantLoRAModule(torch.nn.Module):
            # add a new mid module that will take the original forward and add a vector to it
            # this will be used to add the vector to the original forward

    def forward(self, x):
        return self.linear(x)
    def forward(self, img_embeds):
        img_embeds = self.resampler(img_embeds)
        self.img_embeds = img_embeds
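For intuition (illustrative, not part of the commit): the new forward resamples the incoming image embeddings and caches them on the module, where the per-layer InstantLoRAMidModules read them through their weak reference. Assuming instant_lora_module is an already constructed InstantLoRAModule with 257 vision tokens of width 1280:

import torch

img_embeds = torch.randn(2, 257, 1280)       # hypothetical vision hidden states
instant_lora_module(img_embeds)              # runs the ZipperResampler and caches the result
print(instant_lora_module.img_embeds.shape)  # torch.Size([2, 257, 1280]); in/out sizes match, shape preserved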
toolkit/models/zipper_resampler.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import torch
import torch.nn as nn


class ContextualAlphaMask(nn.Module):
    def __init__(
            self,
            dim: int = 768,
    ):
        super(ContextualAlphaMask, self).__init__()
        self.dim = dim

        half_dim = dim // 2
        quarter_dim = dim // 4

        self.fc1 = nn.Linear(self.dim, self.dim)
        self.fc2 = nn.Linear(self.dim, half_dim)
        self.norm1 = nn.LayerNorm(half_dim)
        self.fc3 = nn.Linear(half_dim, half_dim)
        self.fc4 = nn.Linear(half_dim, quarter_dim)
        self.norm2 = nn.LayerNorm(quarter_dim)
        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
        self.fc6 = nn.Linear(quarter_dim, 1)
        # set fc6 weights to near zero
        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
        self.act_fn = nn.GELU()

    def forward(self, x):
        # x = (batch_size, 77, 768)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        x = self.norm1(x)
        x = self.act_fn(x)
        x = self.fc3(x)
        x = self.act_fn(x)
        x = self.fc4(x)
        x = self.norm2(x)
        x = self.act_fn(x)
        x = self.fc5(x)
        x = self.act_fn(x)
        x = self.fc6(x)
        x = torch.sigmoid(x)
        return x


class ZipperModule(nn.Module):
    def __init__(
            self,
            in_size,
            in_tokens,
            out_size,
            out_tokens,
            hidden_size,
            hidden_tokens,
            use_residual=False,
    ):
        super().__init__()
        self.in_size = in_size
        self.in_tokens = in_tokens
        self.out_size = out_size
        self.out_tokens = out_tokens
        self.hidden_size = hidden_size
        self.hidden_tokens = hidden_tokens
        self.use_residual = use_residual

        self.act_fn = nn.GELU()
        self.layernorm = nn.LayerNorm(self.in_size)

        self.conv1 = nn.Conv1d(self.in_tokens, self.hidden_tokens, 1)
        # act
        self.fc1 = nn.Linear(self.in_size, self.hidden_size)
        # act
        self.conv2 = nn.Conv1d(self.hidden_tokens, self.out_tokens, 1)
        # act
        self.fc2 = nn.Linear(self.hidden_size, self.out_size)

    def forward(self, x):
        residual = x
        x = self.layernorm(x)
        x = self.conv1(x)
        x = self.act_fn(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.conv2(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        if self.use_residual:
            x = x + residual
        return x


class ZipperResampler(nn.Module):
    def __init__(
            self,
            in_size,
            in_tokens,
            out_size,
            out_tokens,
            hidden_size,
            hidden_tokens,
            num_blocks=1,
            is_conv_input=False,
    ):
        super().__init__()
        self.is_conv_input = is_conv_input

        module_list = []
        for i in range(num_blocks):

            this_in_size = in_size
            this_in_tokens = in_tokens
            this_out_size = out_size
            this_out_tokens = out_tokens
            this_hidden_size = hidden_size
            this_hidden_tokens = hidden_tokens
            use_residual = False

            # maintain middle sizes as hidden_size
            if i == 0:  # first block
                this_in_size = in_size
                this_in_tokens = in_tokens
                if num_blocks == 1:
                    this_out_size = out_size
                    this_out_tokens = out_tokens
                else:
                    this_out_size = hidden_size
                    this_out_tokens = hidden_tokens
            elif i == num_blocks - 1:  # last block
                this_out_size = out_size
                this_out_tokens = out_tokens
                if num_blocks == 1:
                    this_in_size = in_size
                    this_in_tokens = in_tokens
                else:
                    this_in_size = hidden_size
                    this_in_tokens = hidden_tokens
            else:  # middle blocks
                this_out_size = hidden_size
                this_out_tokens = hidden_tokens
                this_in_size = hidden_size
                this_in_tokens = hidden_tokens
                use_residual = True

            module_list.append(ZipperModule(
                in_size=this_in_size,
                in_tokens=this_in_tokens,
                out_size=this_out_size,
                out_tokens=this_out_tokens,
                hidden_size=this_hidden_size,
                hidden_tokens=this_hidden_tokens,
                use_residual=use_residual
            ))

        self.blocks = nn.ModuleList(module_list)

        self.ctx_alpha = ContextualAlphaMask(
            dim=out_size,
        )

    def forward(self, x):
        if self.is_conv_input:
            # flatten
            x = x.view(x.size(0), x.size(1), -1)
            # rearrange to (batch, tokens, size)
            x = x.permute(0, 2, 1)

        for block in self.blocks:
            x = block(x)
        alpha = self.ctx_alpha(x)
        return x * alpha
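To make the new module's shapes concrete, here is a usage sketch (illustrative, not part of the commit); the sizes are hypothetical, in the adapters they come from the image encoder config and the UNet's cross_attention_dim:

import torch
from toolkit.models.zipper_resampler import ZipperResampler

# hypothetical sizes: 257 CLIP tokens of width 1280 resampled to 16 tokens of width 2048
resampler = ZipperResampler(
    in_size=1280, in_tokens=257,
    out_size=2048, out_tokens=16,
    hidden_size=1280, hidden_tokens=257,
    num_blocks=2,
)
x = torch.randn(2, 257, 1280)   # (batch, tokens, channels) hidden states
out = resampler(x)              # each output token is gated by the ContextualAlphaMask
print(out.shape)                # torch.Size([2, 16, 2048])

# ConvNeXt-style input: pass the (B, C, H, W) feature map and set is_conv_input=True
conv_resampler = ZipperResampler(
    in_size=1280, in_tokens=16 * 16,
    out_size=2048, out_tokens=16,
    hidden_size=1280, hidden_tokens=16 * 16,
    num_blocks=2, is_conv_input=True,
)
feat = torch.randn(2, 1280, 16, 16)
print(conv_resampler(feat).shape)   # torch.Size([2, 16, 2048])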
@@ -473,7 +473,7 @@ class ToolkitNetworkMixin:
                del load_sd[key]

            print(f"Missing keys: {to_delete}")
        if len(to_delete) > 0 and self.is_v1:
        if len(to_delete) > 0 and self.is_v1 and not force_weight_mapping and not (len(to_delete) == 1 and 'emb_params' in to_delete):
            print(" Attempting to load with forced keymap")
            return self.load_weights(file, force_weight_mapping=True)