From fc34a69bec323828432772afb9ee91f75584501b Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Mon, 9 Sep 2024 16:24:46 -0600
Subject: [PATCH] Ignore guidance embed when full tuning flux. Adjust block
 scaler to decay to 1.0. Add MLP resampler for reducing vision adapter tokens

---
 toolkit/models/vd_adapter.py      | 53 +++++++++++++++++++++++++------
 toolkit/stable_diffusion_model.py | 10 +++---
 2 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index b4c0f5f0..11612134 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -20,25 +20,39 @@ if TYPE_CHECKING:
     from toolkit.custom_adapter import CustomAdapter
 
 
-class MLP(nn.Module):
-    def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
+class MLPR(nn.Module):  # MLP with reshaping
+    def __init__(
+        self,
+        in_dim,
+        in_channels,
+        out_dim,
+        out_channels,
+        hidden_dim,
+        hidden_channels,
+        use_residual=True
+    ):
         super().__init__()
         if use_residual:
             assert in_dim == out_dim
-        self.layernorm = nn.LayerNorm(in_dim)
+        # dont normalize if using conv
+        self.layer_norm = nn.LayerNorm(in_dim)
+
         self.fc1 = nn.Linear(in_dim, hidden_dim)
+        self.conv1 = nn.Conv1d(in_channels, hidden_channels, 1)
         self.fc2 = nn.Linear(hidden_dim, out_dim)
-        self.dropout = nn.Dropout(dropout)
+        self.conv2 = nn.Conv1d(hidden_channels, out_channels, 1)
+
         self.use_residual = use_residual
         self.act_fn = nn.GELU()
 
     def forward(self, x):
         residual = x
-        x = self.layernorm(x)
+        x = self.layer_norm(x)
         x = self.fc1(x)
+        x = self.conv1(x)
         x = self.act_fn(x)
         x = self.fc2(x)
-        x = self.dropout(x)
+        x = self.conv2(x)
         if self.use_residual:
             x = x + residual
         return x
@@ -388,7 +402,8 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
             adapter_hidden_states = self.conditional_embeds
             block_scaler = self.adapter_ref().block_scaler
             if block_scaler is not None:
-                block_scaler = block_scaler[self.block_idx]
+                # add 1 to block scaler so we can decay its weight to 1.0
+                block_scaler = block_scaler[self.block_idx] + 1.0
 
             if adapter_hidden_states.shape[0] < batch_size:
                 adapter_hidden_states = torch.cat([
@@ -620,7 +635,7 @@ class VisionDirectAdapter(torch.nn.Module):
         num_modules = len(self.adapter_modules)
 
         if self.config.train_scaler:
-            self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules).to(
+            self.block_scaler = torch.nn.Parameter(torch.tensor([0.0] * num_modules).to(
                 dtype=torch.float32,
                 device=self.sd_ref().device_torch
             ))
@@ -629,6 +644,24 @@ class VisionDirectAdapter(torch.nn.Module):
         else:
             self.block_scaler = None
 
+        if self.config.num_tokens is not None:
+            image_encoder_state_dict = self.adapter_ref().vision_encoder.state_dict()
+            # max_seq_len = CLIP tokens + CLS token
+            max_seq_len = 257
+            if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+                # clip
+                max_seq_len = int(
+                    image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+            self.resampler = MLPR(
+                in_dim=self.token_size,
+                in_channels=max_seq_len,
+                out_dim=self.token_size,
+                out_channels=self.config.num_tokens,
+                hidden_dim=self.token_size,
+                hidden_channels=max_seq_len,
+                use_residual=False
+            )
+
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         if self.config.train_scaler:
             # only return the block scaler
@@ -648,6 +681,8 @@ class VisionDirectAdapter(torch.nn.Module):
         # todo remove this when we have a real solution
         if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
             self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+        if self.config.num_tokens is not None:
+            input = self.resampler(input)
         return input
 
     def to(self, *args, **kwargs):
@@ -659,6 +694,4 @@ class VisionDirectAdapter(torch.nn.Module):
 
     def post_weight_update(self):
         # force block scaler to be mean of 1
-        # if self.block_scaler is not None:
-        #     self.block_scaler.data = self.block_scaler.data / self.block_scaler.data.mean()
         pass
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 92ea2de7..651a92e6 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -2215,11 +2215,11 @@ class StableDiffusion:
             #     named_params[name] = param
 
             # train the guidance embedding
-            if self.unet.config.guidance_embeds:
-                transformer: FluxTransformer2DModel = self.unet
-                for name, param in transformer.time_text_embed.named_parameters(recurse=True,
-                                                                                 prefix=f"{SD_PREFIX_UNET}"):
-                    named_params[name] = param
+            # if self.unet.config.guidance_embeds:
+            #     transformer: FluxTransformer2DModel = self.unet
+            #     for name, param in transformer.time_text_embed.named_parameters(recurse=True,
+            #                                                                      prefix=f"{SD_PREFIX_UNET}"):
+            #         named_params[name] = param
 
             for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
                                                                              prefix=f"{SD_PREFIX_UNET}"):
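
For reference, the block-scaler change is a reparameterization: the per-block parameter is now initialized to 0.0 and 1.0 is added wherever it is applied, so weight decay pulls the stored offset toward zero and therefore pulls the effective scale toward 1.0 instead of 0.0. A minimal sketch of that idea, with illustrative names that are not from the repo:

import torch
import torch.nn as nn

class BlockScaler(nn.Module):
    def __init__(self, num_blocks: int):
        super().__init__()
        # stored offset starts at 0.0, so the effective scale starts at 1.0
        self.scale_offset = nn.Parameter(torch.zeros(num_blocks, dtype=torch.float32))

    def forward(self, hidden_states: torch.Tensor, block_idx: int) -> torch.Tensor:
        # effective scale = offset + 1.0; weight decay shrinks the offset,
        # which decays the effective scale back toward 1.0
        return hidden_states * (self.scale_offset[block_idx] + 1.0)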
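
The new MLPR resampler reduces the number of vision-adapter tokens by mixing along two axes: the Linear layers mix the feature dimension of each token, while the kernel-size-1 Conv1d layers treat the token axis as channels and map max_seq_len tokens down to config.num_tokens. A toy shape check of that mechanism, using assumed sizes (257 CLIP tokens including CLS, a hypothetical 1024-dim encoder, 32 output tokens):

import torch
import torch.nn as nn

batch, tokens, dim, num_tokens = 2, 257, 1024, 32

x = torch.randn(batch, tokens, dim)       # vision encoder hidden states
fc = nn.Linear(dim, dim)                  # per-token feature mixing, like fc1/fc2
conv = nn.Conv1d(tokens, num_tokens, 1)   # cross-token mixing, like conv1/conv2

y = conv(fc(x))
print(y.shape)                            # torch.Size([2, 32, 1024])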