Mirror of https://github.com/ostris/ai-toolkit.git
Ignore guidance embed when full tuning Flux. Adjust block scaler to decay to 1.0. Add MLP resampler for reducing vision adapter tokens.
```diff
@@ -20,25 +20,39 @@ if TYPE_CHECKING:
     from toolkit.custom_adapter import CustomAdapter
 
 
 class MLP(nn.Module):
     def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
         super().__init__()
         if use_residual:
             assert in_dim == out_dim
         self.layernorm = nn.LayerNorm(in_dim)
         self.fc1 = nn.Linear(in_dim, hidden_dim)
         self.fc2 = nn.Linear(hidden_dim, out_dim)
         self.dropout = nn.Dropout(dropout)
         self.use_residual = use_residual
         self.act_fn = nn.GELU()
 
     def forward(self, x):
         residual = x
         x = self.layernorm(x)
         x = self.fc1(x)
         x = self.act_fn(x)
         x = self.fc2(x)
         x = self.dropout(x)
         if self.use_residual:
             x = x + residual
         return x
+
+
+class MLPR(nn.Module):  # MLP with reshaping
+    def __init__(
+        self,
+        in_dim,
+        in_channels,
+        out_dim,
+        out_channels,
+        hidden_dim,
+        hidden_channels,
+        use_residual=True
+    ):
+        super().__init__()
+        if use_residual:
+            assert in_dim == out_dim
+        # dont normalize if using conv
+        self.layer_norm = nn.LayerNorm(in_dim)
+
+        self.conv1 = nn.Conv1d(in_channels, hidden_channels, 1)
+        self.conv2 = nn.Conv1d(hidden_channels, out_channels, 1)
+
+        self.use_residual = use_residual
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        residual = x
+        x = self.layer_norm(x)
+        x = self.conv1(x)
+        x = self.act_fn(x)
+        x = self.conv2(x)
+        if self.use_residual:
+            x = x + residual
+        return x
```
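The new `MLPR` class swaps the `Linear` layers for kernel-size-1 `Conv1d` layers so it can change the number of tokens rather than the width of each token. `Conv1d` treats dimension 1 of its input as channels, so hidden states shaped `(batch, tokens, dim)` are mixed along the token axis: each output token is a learned linear combination of all input tokens, while the embedding width passes through untouched. A minimal standalone sketch of that reshaping trick (the shapes here are hypothetical, not from the commit):

```python
import torch
import torch.nn as nn

batch, num_tokens, dim = 2, 257, 1024          # hypothetical CLIP-like sequence
x = torch.randn(batch, num_tokens, dim)

# Kernel-size-1 Conv1d over the token axis: 257 tokens in, 16 tokens out.
# Each of the 16 outputs is a learned weighted sum of the 257 inputs,
# applied independently at every one of the 1024 embedding positions.
resample = nn.Conv1d(in_channels=num_tokens, out_channels=16, kernel_size=1)

y = resample(x)
print(y.shape)  # torch.Size([2, 16, 1024])
```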
```diff
@@ -388,7 +402,8 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
             adapter_hidden_states = self.conditional_embeds
             block_scaler = self.adapter_ref().block_scaler
             if block_scaler is not None:
-                block_scaler = block_scaler[self.block_idx]
+                # add 1 to block scaler so we can decay its weight to 1.0
+                block_scaler = block_scaler[self.block_idx] + 1.0
 
             if adapter_hidden_states.shape[0] < batch_size:
                 adapter_hidden_states = torch.cat([
```
```diff
@@ -620,7 +635,7 @@ class VisionDirectAdapter(torch.nn.Module):
 
         num_modules = len(self.adapter_modules)
         if self.config.train_scaler:
-            self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules).to(
+            self.block_scaler = torch.nn.Parameter(torch.tensor([0.0] * num_modules).to(
                 dtype=torch.float32,
                 device=self.sd_ref().device_torch
             ))
```
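Taken together with the `+ 1.0` in the attention processor above, the zero initialization reparameterizes the per-block scaler as an offset from 1.0: the stored parameter starts at 0.0, so the effective multiplier starts at 1.0. Since weight decay pulls parameters toward zero, decay now relaxes the effective scale toward the identity value 1.0 rather than toward 0, which is what the commit message means by letting the block scaler "decay to 1.0". A hedged sketch of the mechanics (toy sizes, not the real training loop):

```python
import torch

# Scale stored as an offset from 1.0; the commit initializes it at 0.0.
offset = torch.nn.Parameter(torch.full((4,), 0.5))  # pretend training moved it to 0.5
opt = torch.optim.AdamW([offset], lr=0.1, weight_decay=0.1)

# A step with zero task gradient: decoupled weight decay still shrinks the
# offset toward 0, so the effective scale (offset + 1.0) drifts back to 1.0.
offset.grad = torch.zeros_like(offset)
opt.step()
print(offset + 1.0)  # slightly below 1.5, heading toward 1.0
```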
```diff
@@ -629,6 +644,24 @@ class VisionDirectAdapter(torch.nn.Module):
         else:
             self.block_scaler = None
 
+        if self.config.num_tokens is not None:
+            image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
+            # max_seq_len = CLIP tokens + CLS token
+            max_seq_len = 257
+            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+                # clip
+                max_seq_len = int(
+                    image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+            self.resampler = MLPR(
+                in_dim=self.token_size,
+                in_channels=max_seq_len,
+                out_dim=self.token_size,
+                out_channels=self.config.num_tokens,
+                hidden_dim=self.token_size,
+                hidden_channels=max_seq_len,
+                use_residual=False
+            )
+
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         if self.config.train_scaler:
             # only return the block scaler
```
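Why 257: a CLIP ViT-L/14 vision tower at 224 px emits 16×16 = 256 patch tokens plus one CLS token, and the encoder's position-embedding table has exactly one row per token, so its first dimension gives `max_seq_len` when present. A usage sketch for the resampler (assumes the `MLPR` class from the first hunk is in scope; the config numbers are hypothetical):

```python
import torch

token_size, max_seq_len, num_tokens = 1024, 257, 32   # hypothetical config values

resampler = MLPR(
    in_dim=token_size, in_channels=max_seq_len,
    out_dim=token_size, out_channels=num_tokens,
    hidden_dim=token_size, hidden_channels=max_seq_len,
    use_residual=False,  # input and output token counts differ, so no residual add
)

vision_states = torch.randn(1, max_seq_len, token_size)  # encoder hidden states
print(resampler(vision_states).shape)  # torch.Size([1, 32, 1024])
```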
```diff
@@ -648,6 +681,8 @@ class VisionDirectAdapter(torch.nn.Module):
         # todo remove this when we have a real solution
         if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
             self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+        if self.config.num_tokens is not None:
+            input = self.resampler(input)
         return input
 
     def to(self, *args, **kwargs):
```
```diff
@@ -659,6 +694,4 @@ class VisionDirectAdapter(torch.nn.Module):
 
     def post_weight_update(self):
-        # force block scaler to be mean of 1
-        # if self.block_scaler is not None:
-        #     self.block_scaler.data = self.block_scaler.data / self.block_scaler.data.mean()
+        pass
```
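With the scaler now stored as an offset from 1.0, the old post-update renormalization of `block_scaler` to a mean of 1 is redundant: weight decay already supplies the pull toward an effective scale of 1.0, which is presumably why `post_weight_update` is reduced to a no-op here.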
```diff
@@ -2215,11 +2215,11 @@ class StableDiffusion:
                     # named_params[name] = param
 
             # train the guidance embedding
-            if self.unet.config.guidance_embeds:
-                transformer: FluxTransformer2DModel = self.unet
-                for name, param in transformer.time_text_embed.named_parameters(recurse=True,
-                                                                                prefix=f"{SD_PREFIX_UNET}"):
-                    named_params[name] = param
+            # if self.unet.config.guidance_embeds:
+            #     transformer: FluxTransformer2DModel = self.unet
+            #     for name, param in transformer.time_text_embed.named_parameters(recurse=True,
+            #                                                                     prefix=f"{SD_PREFIX_UNET}"):
+            #         named_params[name] = param
 
             for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
                                                                              prefix=f"{SD_PREFIX_UNET}"):
```
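This last hunk comments out the branch that added the guidance/timestep embedding parameters (`time_text_embed`) to the trainable set, so a full fine-tune of Flux now leaves that embedding frozen and trains only the transformer blocks gathered below it. A self-contained sketch of the gathering pattern using a toy module (not the real `FluxTransformer2DModel`):

```python
import torch.nn as nn

class ToyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_text_embed = nn.Linear(8, 8)  # guidance embedding branch, now skipped
        self.transformer_blocks = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

model = ToyTransformer()
named_params = {}
# Only the blocks are collected; time_text_embed stays out of the optimizer.
for name, param in model.transformer_blocks.named_parameters(recurse=True, prefix="unet"):
    named_params[name] = param

print(sorted(named_params))
# ['unet.0.bias', 'unet.0.weight', 'unet.1.bias', 'unet.1.weight']
```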