Ignore guidance embed when full tuning Flux. Adjust block scaler to decay to 1.0. Add MLP resampler for reducing vision adapter tokens

Jaret Burkett
2024-09-09 16:24:46 -06:00
parent 279ee65177
commit fc34a69bec
2 changed files with 48 additions and 15 deletions


@@ -20,25 +20,39 @@ if TYPE_CHECKING:
from toolkit.custom_adapter import CustomAdapter
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
class MLPR(nn.Module): # MLP with reshaping
def __init__(
self,
in_dim,
in_channels,
out_dim,
out_channels,
hidden_dim,
hidden_channels,
use_residual=True
):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = nn.LayerNorm(in_dim)
# dont normalize if using conv
self.layer_norm = nn.LayerNorm(in_dim)
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.conv1 = nn.Conv1d(in_channels, hidden_channels, 1)
self.fc2 = nn.Linear(hidden_dim, out_dim)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv1d(hidden_channels, out_channels, 1)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
residual = x
x = self.layernorm(x)
x = self.layer_norm(x)
x = self.fc1(x)
x = self.conv1(x)
x = self.act_fn(x)
x = self.fc2(x)
x = self.dropout(x)
x = self.conv2(x)
if self.use_residual:
x = x + residual
return x
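
The resampler above replaces the MLP's Linear layers with kernel-size-1 Conv1d layers so that mixing happens along the token axis instead of the feature axis: with input shaped [batch, tokens, dim], Conv1d treats the token count as its channel dimension, which is what lets it map a fixed-length CLIP sequence down to a smaller number of tokens. A minimal, self-contained sketch of that behavior (sizes are illustrative, not the repo's defaults):

import torch
import torch.nn as nn

# Illustrative sizes: 257 CLIP vision tokens resampled down to 32 adapter tokens.
batch, in_tokens, out_tokens, dim = 2, 257, 32, 1024

conv1 = nn.Conv1d(in_tokens, in_tokens, 1)   # hidden_channels == in_channels, as in MLPR above
conv2 = nn.Conv1d(in_tokens, out_tokens, 1)  # reduces the token count
act = nn.GELU()

x = torch.randn(batch, in_tokens, dim)       # [batch, tokens, dim] vision hidden states
y = conv2(act(conv1(x)))
print(y.shape)                                # torch.Size([2, 32, 1024])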
@@ -388,7 +402,8 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
adapter_hidden_states = self.conditional_embeds
block_scaler = self.adapter_ref().block_scaler
if block_scaler is not None:
block_scaler = block_scaler[self.block_idx]
# add 1 to block scaler so we can decay its weight to 1.0
block_scaler = block_scaler[self.block_idx] + 1.0
if adapter_hidden_states.shape[0] < batch_size:
adapter_hidden_states = torch.cat([
@@ -620,7 +635,7 @@ class VisionDirectAdapter(torch.nn.Module):
num_modules = len(self.adapter_modules)
if self.config.train_scaler:
self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules).to(
self.block_scaler = torch.nn.Parameter(torch.tensor([0.0] * num_modules).to(
dtype=torch.float32,
device=self.sd_ref().device_torch
))
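
The two hunks above work as a pair: the per-block scaler is now initialized to 0.0 and treated as an offset, and the attention processor adds 1.0 back when applying it. Stored that way, ordinary weight decay pulls the parameter toward 0.0, so the effective per-block scale decays toward 1.0 (a no-op) instead of toward 0. A toy sketch of that effect, using illustrative values and an AdamW loop rather than the repo's trainer:

import torch

# Pretend mid-training offsets; in the repo they start at 0.0.
block_scaler = torch.nn.Parameter(torch.tensor([0.4, -0.3, 0.1, 0.2]))
opt = torch.optim.AdamW([block_scaler], lr=1e-2, weight_decay=0.5)

for _ in range(1000):
    effective = block_scaler + 1.0   # what the attention processor multiplies by
    loss = effective.sum() * 0.0     # flat task loss here, so only weight decay acts
    opt.zero_grad()
    loss.backward()
    opt.step()

print((block_scaler + 1.0).data)     # drifts back toward tensor([1., 1., 1., 1.])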
@@ -629,6 +644,24 @@ class VisionDirectAdapter(torch.nn.Module):
else:
self.block_scaler = None
if self.config.num_tokens is not None:
image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
# max_seq_len = CLIP tokens + CLS token
max_seq_len = 257
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
# clip
max_seq_len = int(
image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
self.resampler = MLPR(
in_dim=self.token_size,
in_channels=max_seq_len,
out_dim=self.token_size,
out_channels=self.config.num_tokens,
hidden_dim=self.token_size,
hidden_channels=max_seq_len,
use_residual=False
)
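
Here max_seq_len is the number of tokens the vision tower emits (patch tokens plus the CLS token), read from the position-embedding weight when the encoder exposes the CLIP key, with 257 as the fallback. A hedged sketch of that lookup using transformers' CLIPVisionModel (the checkpoint name is illustrative):

from transformers import CLIPVisionModel

vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
sd = vision_encoder.state_dict()

key = "vision_model.embeddings.position_embedding.weight"
# 257 = 16x16 patch tokens + 1 CLS token for a 224px ViT-L/14 input
max_seq_len = int(sd[key].shape[0]) if key in sd else 257
print(max_seq_len)  # 257

That count becomes MLPR's in_channels (and hidden_channels), while config.num_tokens sets out_channels, so the resampler maps [batch, max_seq_len, token_size] down to [batch, num_tokens, token_size].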
def state_dict(self, destination=None, prefix='', keep_vars=False):
if self.config.train_scaler:
# only return the block scaler
@@ -648,6 +681,8 @@ class VisionDirectAdapter(torch.nn.Module):
# todo remove this when we have a real solution
if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
self.block_scaler.data = self.block_scaler.data.to(torch.float32)
if self.config.num_tokens is not None:
input = self.resampler(input)
return input
def to(self, *args, **kwargs):
@@ -659,6 +694,4 @@ class VisionDirectAdapter(torch.nn.Module):
def post_weight_update(self):
# force block scaler to be mean of 1
# if self.block_scaler is not None:
# self.block_scaler.data = self.block_scaler.data / self.block_scaler.data.mean()
pass
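
Taken together, the adapter-side hunks above give the conditioning path roughly this shape: keep the block scaler in float32, resample the vision tokens when a token count is configured, and drop the old mean renormalization in post_weight_update since weight decay on the offset already pulls the effective scale toward 1.0. A hedged sketch of that path (names are illustrative, not the repo's exact method):

import torch

def condition_vision_tokens(adapter, vision_tokens: torch.Tensor) -> torch.Tensor:
    # Keep the trainable block scaler in float32 even if the rest of the
    # adapter was cast to lower precision.
    if adapter.block_scaler is not None and adapter.block_scaler.dtype != torch.float32:
        adapter.block_scaler.data = adapter.block_scaler.data.to(torch.float32)
    # Resample [batch, max_seq_len, dim] -> [batch, num_tokens, dim] before the
    # tokens reach the attention processors.
    if adapter.config.num_tokens is not None:
        vision_tokens = adapter.resampler(vision_tokens)
    return vision_tokens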


@@ -2215,11 +2215,11 @@ class StableDiffusion:
# named_params[name] = param
# train the guidance embedding
if self.unet.config.guidance_embeds:
transformer: FluxTransformer2DModel = self.unet
for name, param in transformer.time_text_embed.named_parameters(recurse=True,
prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
# if self.unet.config.guidance_embeds:
# transformer: FluxTransformer2DModel = self.unet
# for name, param in transformer.time_text_embed.named_parameters(recurse=True,
# prefix=f"{SD_PREFIX_UNET}"):
# named_params[name] = param
for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
prefix=f"{SD_PREFIX_UNET}"):
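
With the guidance-embedding block commented out, a guidance-distilled Flux transformer's time_text_embed (which holds the guidance embedder) stays frozen during full fine-tuning, and trainable parameters come from the transformer blocks instead. A hedged, simplified stand-in for the surrounding method (the prefix and further param groups are illustrative):

from diffusers import FluxTransformer2DModel

def collect_full_tune_params(transformer: FluxTransformer2DModel, prefix: str = "transformer"):
    named_params = {}
    # Double-stream blocks, as in the loop above; time_text_embed is intentionally skipped.
    for name, param in transformer.transformer_blocks.named_parameters(recurse=True, prefix=prefix):
        named_params[name] = param
    # Further groups (e.g. single_transformer_blocks) would be collected the same way.
    return named_params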