# Mirror of https://github.com/ostris/ai-toolkit.git
import torch
import torch.nn as nn


class ContextualAlphaMask(nn.Module):
    """Predicts a per-token alpha weight in (0, 1) from token embeddings.

    A small MLP that steps the feature dimension down from ``dim`` to
    ``dim // 2`` to ``dim // 4`` and ends in a single sigmoid unit per token.
    """

    def __init__(
        self,
        dim: int = 768,
    ):
        super(ContextualAlphaMask, self).__init__()
        self.dim = dim

        half_dim = dim // 2
        quarter_dim = dim // 4

        self.fc1 = nn.Linear(self.dim, self.dim)
        self.fc2 = nn.Linear(self.dim, half_dim)
        self.norm1 = nn.LayerNorm(half_dim)
        self.fc3 = nn.Linear(half_dim, half_dim)
        self.fc4 = nn.Linear(half_dim, quarter_dim)
        self.norm2 = nn.LayerNorm(quarter_dim)
        self.fc5 = nn.Linear(quarter_dim, quarter_dim)
        self.fc6 = nn.Linear(quarter_dim, 1)
        # set fc6 weights to near zero so the initial alpha is nearly
        # uniform across tokens (driven almost entirely by the fc6 bias)
        self.fc6.weight.data.normal_(mean=0.0, std=0.0001)
        self.act_fn = nn.GELU()

    def forward(self, x):
        # x = (batch_size, 77, 768)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        x = self.norm1(x)
        x = self.act_fn(x)
        x = self.fc3(x)
        x = self.act_fn(x)
        x = self.fc4(x)
        x = self.norm2(x)
        x = self.act_fn(x)
        x = self.fc5(x)
        x = self.act_fn(x)
        x = self.fc6(x)
        x = torch.sigmoid(x)
        # (batch_size, 77, 1): one alpha per token
        return x

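# A minimal usage sketch (assumed, not from the upstream repo): one alpha
# per token for CLIP-style text embeddings.
#
#   mask = ContextualAlphaMask(dim=768)
#   tokens = torch.randn(2, 77, 768)
#   alpha = mask(tokens)  # (2, 77, 1), values in (0, 1)
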
class ZipperModule(nn.Module):
    """One resampling block that reshapes both axes of a token sequence.

    1x1 Conv1d layers mix along the token axis (in_tokens -> hidden_tokens
    -> out_tokens) while Linear layers mix along the feature axis (in_size
    -> hidden_size -> out_size), so a single block can change both the
    number of tokens and their dimensionality.
    """

    def __init__(
        self,
        in_size,
        in_tokens,
        out_size,
        out_tokens,
        hidden_size,
        hidden_tokens,
        use_residual=False,
    ):
        super().__init__()
        self.in_size = in_size
        self.in_tokens = in_tokens
        self.out_size = out_size
        self.out_tokens = out_tokens
        self.hidden_size = hidden_size
        self.hidden_tokens = hidden_tokens
        # a residual connection is only shape-safe when the input and
        # output shapes match (the middle blocks of ZipperResampler)
        self.use_residual = use_residual

        self.act_fn = nn.GELU()
        self.layernorm = nn.LayerNorm(self.in_size)

        self.conv1 = nn.Conv1d(self.in_tokens, self.hidden_tokens, 1)
        # act
        self.fc1 = nn.Linear(self.in_size, self.hidden_size)
        # act
        self.conv2 = nn.Conv1d(self.hidden_tokens, self.out_tokens, 1)
        # act
        self.fc2 = nn.Linear(self.hidden_size, self.out_size)

    def forward(self, x):
        # x = (batch, in_tokens, in_size)
        residual = x
        x = self.layernorm(x)
        x = self.conv1(x)  # (batch, hidden_tokens, in_size)
        x = self.act_fn(x)
        x = self.fc1(x)  # (batch, hidden_tokens, hidden_size)
        x = self.act_fn(x)
        x = self.conv2(x)  # (batch, out_tokens, hidden_size)
        x = self.act_fn(x)
        x = self.fc2(x)  # (batch, out_tokens, out_size)
        if self.use_residual:
            x = x + residual
        return x

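# A minimal usage sketch (assumed, not from the upstream repo): shrink
# 257 tokens of width 1024 to 77 tokens of width 768 in one block.
#
#   zipper = ZipperModule(in_size=1024, in_tokens=257, out_size=768,
#                         out_tokens=77, hidden_size=512, hidden_tokens=128)
#   out = zipper(torch.randn(2, 257, 1024))  # (2, 77, 768)
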
class ZipperResampler(nn.Module):
    """Stacks ZipperModule blocks, then scales the result by a learned
    per-token alpha from ContextualAlphaMask.

    The first block maps the input shape to the hidden shape, middle blocks
    keep the hidden shape (with residual connections), and the last block
    maps to the output shape. With num_blocks == 1 a single block maps the
    input directly to the output.
    """

    def __init__(
        self,
        in_size,
        in_tokens,
        out_size,
        out_tokens,
        hidden_size,
        hidden_tokens,
        num_blocks=1,
        is_conv_input=False,
    ):
        super().__init__()
        self.is_conv_input = is_conv_input

        module_list = []
        for i in range(num_blocks):

            this_in_size = in_size
            this_in_tokens = in_tokens
            this_out_size = out_size
            this_out_tokens = out_tokens
            this_hidden_size = hidden_size
            this_hidden_tokens = hidden_tokens
            use_residual = False

            # maintain middle sizes as hidden_size
            if i == 0:  # first block
                this_in_size = in_size
                this_in_tokens = in_tokens
                if num_blocks == 1:
                    this_out_size = out_size
                    this_out_tokens = out_tokens
                else:
                    this_out_size = hidden_size
                    this_out_tokens = hidden_tokens
            elif i == num_blocks - 1:  # last block
                this_out_size = out_size
                this_out_tokens = out_tokens
                if num_blocks == 1:
                    this_in_size = in_size
                    this_in_tokens = in_tokens
                else:
                    this_in_size = hidden_size
                    this_in_tokens = hidden_tokens
            else:  # middle blocks
                this_out_size = hidden_size
                this_out_tokens = hidden_tokens
                this_in_size = hidden_size
                this_in_tokens = hidden_tokens
                # middle blocks keep a constant shape, so a residual is safe
                use_residual = True

            module_list.append(ZipperModule(
                in_size=this_in_size,
                in_tokens=this_in_tokens,
                out_size=this_out_size,
                out_tokens=this_out_tokens,
                hidden_size=this_hidden_size,
                hidden_tokens=this_hidden_tokens,
                use_residual=use_residual
            ))

        self.blocks = nn.ModuleList(module_list)

        self.ctx_alpha = ContextualAlphaMask(
            dim=out_size,
        )

    def forward(self, x):
        if self.is_conv_input:
            # flatten (batch, channels, h, w) -> (batch, channels, h * w)
            x = x.view(x.size(0), x.size(1), -1)
            # rearrange to (batch, tokens, size)
            x = x.permute(0, 2, 1)

        for block in self.blocks:
            x = block(x)
        # gate each output token by its predicted alpha in (0, 1)
        alpha = self.ctx_alpha(x)
        return x * alpha
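

if __name__ == "__main__":
    # Minimal smoke test (assumed usage; not part of the upstream module).
    # Resample 257 ViT-style image tokens of width 1024 into 77 tokens of
    # width 768; num_blocks=3 also exercises the residual middle block.
    resampler = ZipperResampler(
        in_size=1024,
        in_tokens=257,
        out_size=768,
        out_tokens=77,
        hidden_size=512,
        hidden_tokens=128,
        num_blocks=3,
    )
    x = torch.randn(2, 257, 1024)
    out = resampler(x)
    print(out.shape)  # torch.Size([2, 77, 768])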