Work on additional image embedding methods. Finalized zipper resampler. It works amazing

This commit is contained in:
Jaret Burkett
2024-02-10 09:00:05 -07:00
parent a8481c1670
commit e074058faa
7 changed files with 261 additions and 47 deletions

View File

@@ -1,6 +1,8 @@
import torch
import torch.nn as nn
from toolkit.models.zipper_resampler import ContextualAlphaMask
# Conv1d MLP
# MLP that can alternately be used as a conv1d on dim 1
@@ -86,46 +88,7 @@ class ZipperBlock(nn.Module):
return x
class ContextualAlphaMask(nn.Module):
def __init__(
self,
dim: int = 768,
):
super(ContextualAlphaMask, self).__init__()
self.dim = dim
half_dim = dim // 2
quarter_dim = dim // 4
self.fc1 = nn.Linear(self.dim, self.dim)
self.fc2 = nn.Linear(self.dim, half_dim)
self.norm1 = nn.LayerNorm(half_dim)
self.fc3 = nn.Linear(half_dim, half_dim)
self.fc4 = nn.Linear(half_dim, quarter_dim)
self.norm2 = nn.LayerNorm(quarter_dim)
self.fc5 = nn.Linear(quarter_dim, quarter_dim)
self.fc6 = nn.Linear(quarter_dim, 1)
# set fc6 weights to near zero
self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
self.act_fn = nn.GELU()
def forward(self, x):
# x = (batch_size, 77, 768)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x = self.norm1(x)
x = self.act_fn(x)
x = self.fc3(x)
x = self.act_fn(x)
x = self.fc4(x)
x = self.norm2(x)
x = self.act_fn(x)
x = self.fc5(x)
x = self.act_fn(x)
x = self.fc6(x)
x = torch.sigmoid(x)
return x

View File

@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
from typing import TYPE_CHECKING
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
if TYPE_CHECKING:
from toolkit.lora_special import LoRAModule
@@ -26,7 +27,7 @@ class InstantLoRAMidModule(torch.nn.Module):
self.lora_module_ref = weakref.ref(lora_module)
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
self.zip = ZipperBlock(
self.zip = ZipperModule(
in_size=self.vision_hidden_size,
in_tokens=self.vision_tokens,
out_size=self.dim,
@@ -71,7 +72,7 @@ class InstantLoRAModule(torch.nn.Module):
sd: 'StableDiffusion'
):
super(InstantLoRAModule, self).__init__()
self.linear = torch.nn.Linear(2, 1)
# self.linear = torch.nn.Linear(2, 1)
self.sd_ref = weakref.ref(sd)
self.dim = sd.network.lora_dim
self.vision_hidden_size = vision_hidden_size
@@ -83,6 +84,15 @@ class InstantLoRAModule(torch.nn.Module):
# disable merging in. It is slower on inference
self.sd_ref().network.can_merge_in = False
self.resampler = ZipperResampler(
in_size=self.vision_hidden_size,
in_tokens=self.vision_tokens,
out_size=self.vision_hidden_size,
out_tokens=self.vision_tokens,
hidden_size=self.vision_hidden_size,
hidden_tokens=self.vision_tokens
)
self.ilora_modules = torch.nn.ModuleList()
lora_modules = self.sd_ref().network.get_all_modules()
@@ -99,5 +109,7 @@ class InstantLoRAModule(torch.nn.Module):
# add a new mid module that will take the original forward and add a vector to it
# this will be used to add the vector to the original forward
def forward(self, x):
return self.linear(x)
def forward(self, img_embeds):
img_embeds = self.resampler(img_embeds)
self.img_embeds = img_embeds

View File

@@ -0,0 +1,171 @@
import torch
import torch.nn as nn
class ContextualAlphaMask(nn.Module):
def __init__(
self,
dim: int = 768,
):
super(ContextualAlphaMask, self).__init__()
self.dim = dim
half_dim = dim // 2
quarter_dim = dim // 4
self.fc1 = nn.Linear(self.dim, self.dim)
self.fc2 = nn.Linear(self.dim, half_dim)
self.norm1 = nn.LayerNorm(half_dim)
self.fc3 = nn.Linear(half_dim, half_dim)
self.fc4 = nn.Linear(half_dim, quarter_dim)
self.norm2 = nn.LayerNorm(quarter_dim)
self.fc5 = nn.Linear(quarter_dim, quarter_dim)
self.fc6 = nn.Linear(quarter_dim, 1)
# set fc6 weights to near zero
self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
self.act_fn = nn.GELU()
def forward(self, x):
# x = (batch_size, 77, 768)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x = self.norm1(x)
x = self.act_fn(x)
x = self.fc3(x)
x = self.act_fn(x)
x = self.fc4(x)
x = self.norm2(x)
x = self.act_fn(x)
x = self.fc5(x)
x = self.act_fn(x)
x = self.fc6(x)
x = torch.sigmoid(x)
return x
class ZipperModule(nn.Module):
def __init__(
self,
in_size,
in_tokens,
out_size,
out_tokens,
hidden_size,
hidden_tokens,
use_residual=False,
):
super().__init__()
self.in_size = in_size
self.in_tokens = in_tokens
self.out_size = out_size
self.out_tokens = out_tokens
self.hidden_size = hidden_size
self.hidden_tokens = hidden_tokens
self.use_residual = use_residual
self.act_fn = nn.GELU()
self.layernorm = nn.LayerNorm(self.in_size)
self.conv1 = nn.Conv1d(self.in_tokens, self.hidden_tokens, 1)
# act
self.fc1 = nn.Linear(self.in_size, self.hidden_size)
# act
self.conv2 = nn.Conv1d(self.hidden_tokens, self.out_tokens, 1)
# act
self.fc2 = nn.Linear(self.hidden_size, self.out_size)
def forward(self, x):
residual = x
x = self.layernorm(x)
x = self.conv1(x)
x = self.act_fn(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.conv2(x)
x = self.act_fn(x)
x = self.fc2(x)
if self.use_residual:
x = x + residual
return x
class ZipperResampler(nn.Module):
def __init__(
self,
in_size,
in_tokens,
out_size,
out_tokens,
hidden_size,
hidden_tokens,
num_blocks=1,
is_conv_input=False,
):
super().__init__()
self.is_conv_input = is_conv_input
module_list = []
for i in range(num_blocks):
this_in_size = in_size
this_in_tokens = in_tokens
this_out_size = out_size
this_out_tokens = out_tokens
this_hidden_size = hidden_size
this_hidden_tokens = hidden_tokens
use_residual = False
# maintain middle sizes as hidden_size
if i == 0: # first block
this_in_size = in_size
this_in_tokens = in_tokens
if num_blocks == 1:
this_out_size = out_size
this_out_tokens = out_tokens
else:
this_out_size = hidden_size
this_out_tokens = hidden_tokens
elif i == num_blocks - 1: # last block
this_out_size = out_size
this_out_tokens = out_tokens
if num_blocks == 1:
this_in_size = in_size
this_in_tokens = in_tokens
else:
this_in_size = hidden_size
this_in_tokens = hidden_tokens
else: # middle blocks
this_out_size = hidden_size
this_out_tokens = hidden_tokens
this_in_size = hidden_size
this_in_tokens = hidden_tokens
use_residual = True
module_list.append(ZipperModule(
in_size=this_in_size,
in_tokens=this_in_tokens,
out_size=this_out_size,
out_tokens=this_out_tokens,
hidden_size=this_hidden_size,
hidden_tokens=this_hidden_tokens,
use_residual=use_residual
))
self.blocks = nn.ModuleList(module_list)
self.ctx_alpha = ContextualAlphaMask(
dim=out_size,
)
def forward(self, x):
if self.is_conv_input:
# flatten
x = x.view(x.size(0), x.size(1), -1)
# rearrange to (batch, tokens, size)
x = x.permute(0, 2, 1)
for block in self.blocks:
x = block(x)
alpha = self.ctx_alpha(x)
return x * alpha