Added support for training vision direct weight adapters

Jaret Burkett
2024-09-05 10:11:44 -06:00
parent 5c8fcc8a4e
commit 3a1f464132
3 changed files with 78 additions and 16 deletions

View File

@@ -1584,6 +1584,8 @@ class SDTrainer(BaseSDTrainProcess):
             self.scaler.update()
             self.optimizer.zero_grad(set_to_none=True)
+            if self.adapter and isinstance(self.adapter, CustomAdapter):
+                self.adapter.post_weight_update()
             if self.ema is not None:
                 with self.timer('ema_update'):
                     self.ema.update()
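
Note: the hook fires once per optimizer step, after scaler.update() and zero_grad(), so the adapter can adjust its weights outside of autograd. A minimal sketch of the pattern with a toy adapter standing in for CustomAdapter (all names besides post_weight_update are illustrative):

    import torch

    class ToyAdapter(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.block_scaler = torch.nn.Parameter(torch.ones(4))

        def post_weight_update(self):
            # runs after optimizer.step(); safe to mutate .data directly
            self.block_scaler.data = self.block_scaler.data / self.block_scaler.data.mean()

    adapter = ToyAdapter()
    optimizer = torch.optim.SGD(adapter.parameters(), lr=0.1)

    loss = (adapter.block_scaler - 2.0).pow(2).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    adapter.post_weight_update()  # same call site as in the trainer loop above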

View File

@@ -402,7 +402,7 @@ class CustomAdapter(torch.nn.Module):
         if 'sv_adapter' in state_dict:
             self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
-        if 'vision_encoder' in state_dict and self.config.train_image_encoder:
+        if 'vision_encoder' in state_dict:
             self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
         if 'fuse_module' in state_dict:
@@ -881,10 +881,14 @@ class CustomAdapter(torch.nn.Module):
             for attn_processor in self.te_adapter.adapter_modules:
                 yield from attn_processor.parameters(recurse)
         elif self.config.type == 'vision_direct':
-            for attn_processor in self.vd_adapter.adapter_modules:
-                yield from attn_processor.parameters(recurse)
-            if self.config.train_image_encoder:
-                yield from self.vision_encoder.parameters(recurse)
+            if self.config.train_scaler:
+                # only yield the block_scaler (a torch.nn.Parameter of shape [num_modules])
+                yield self.vd_adapter.block_scaler
+            else:
+                for attn_processor in self.vd_adapter.adapter_modules:
+                    yield from attn_processor.parameters(recurse)
+                if self.config.train_image_encoder:
+                    yield from self.vision_encoder.parameters(recurse)
         elif self.config.type == 'te_augmenter':
             yield from self.te_augmenter.parameters(recurse)
             if self.config.train_image_encoder:
@@ -908,4 +912,10 @@ class CustomAdapter(torch.nn.Module):
             additional[k] = v
         additional['clip_layer'] = self.config.clip_layer
         additional['image_encoder_arch'] = self.config.head_dim
-        return additional
+        return additional
+
+    def post_weight_update(self):
+        # do any kind of updates after the weight update
+        if self.config.type == 'vision_direct':
+            self.vd_adapter.post_weight_update()
+        pass
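
With train_scaler enabled, parameters() above exposes only the block_scaler tensor, so the optimizer never registers the attention-processor or vision-encoder weights. A simplified sketch of that gating (class and attribute names are illustrative):

    import torch

    class GatedAdapter(torch.nn.Module):
        def __init__(self, train_scaler: bool):
            super().__init__()
            self.train_scaler = train_scaler
            self.block_scaler = torch.nn.Parameter(torch.ones(8))
            self.heavy = torch.nn.Linear(64, 64)  # stands in for the adapter modules

        def parameters(self, recurse: bool = True):
            if self.train_scaler:
                yield self.block_scaler  # the optimizer sees exactly one tensor
            else:
                yield from super().parameters(recurse)

    opt = torch.optim.AdamW(GatedAdapter(train_scaler=True).parameters(), lr=1e-3)
    assert len(opt.param_groups[0]['params']) == 1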

View File

@@ -5,9 +5,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 import weakref
 from typing import Union, TYPE_CHECKING, Optional
+from collections import OrderedDict

 from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
+from toolkit.config_modules import AdapterConfig
 from toolkit.paths import REPOS_ROOT

 sys.path.append(REPOS_ROOT)
@@ -286,7 +289,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
adapter_hidden_size=None, has_bias=False, **kwargs):
adapter_hidden_size=None, has_bias=False, block_idx=0, **kwargs):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
@@ -298,6 +301,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         self.adapter_hidden_size = adapter_hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.scale = scale
+        self.block_idx = block_idx

         self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
         self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
@@ -382,6 +386,10 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         # begin ip adapter
         if self.is_active and self.conditional_embeds is not None:
             adapter_hidden_states = self.conditional_embeds
+            block_scaler = self.adapter_ref().block_scaler
+            if block_scaler is not None:
+                block_scaler = block_scaler[self.block_idx]
+
             if adapter_hidden_states.shape[0] < batch_size:
                 adapter_hidden_states = torch.cat([
                     self.unconditional_embeds,
@@ -407,6 +415,15 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
             vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
             vd_hidden_states = vd_hidden_states.to(query.dtype)

+            # scale to block scaler
+            if block_scaler is not None:
+                orig_dtype = vd_hidden_states.dtype
+                if block_scaler.dtype != vd_hidden_states.dtype:
+                    vd_hidden_states = vd_hidden_states.to(block_scaler.dtype)
+                vd_hidden_states = vd_hidden_states * block_scaler
+                if block_scaler.dtype != orig_dtype:
+                    vd_hidden_states = vd_hidden_states.to(orig_dtype)
+
             hidden_states = hidden_states + self.scale * vd_hidden_states

         if encoder_hidden_states is not None:
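
Because block_scaler is pinned to float32 while the hidden states usually run in half precision, the multiply happens in float32 and the result is cast back. The round-trip in isolation (bf16 activations assumed for the example):

    import torch

    vd_hidden_states = torch.randn(2, 16, dtype=torch.bfloat16)
    block_scaler = torch.tensor(1.5, dtype=torch.float32)

    orig_dtype = vd_hidden_states.dtype
    out = vd_hidden_states.to(block_scaler.dtype) * block_scaler  # multiply in float32
    out = out.to(orig_dtype)                                      # back to bf16 for the attention math
    assert out.dtype == torch.bfloat16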
@@ -437,6 +454,7 @@ class VisionDirectAdapter(torch.nn.Module):
         is_flux = sd.is_flux
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.sd_ref: weakref.ref = weakref.ref(sd)
+        self.config: AdapterConfig = adapter.config
         self.vision_model_ref: weakref.ref = weakref.ref(vision_model)

         if adapter.config.clip_layer == "image_embeds":
@@ -469,6 +487,8 @@ class VisionDirectAdapter(torch.nn.Module):
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())

+        current_idx = 0
+
         for name in attn_processor_keys:
             if is_flux:
                 cross_attention_dim = None
@@ -550,6 +570,7 @@ class VisionDirectAdapter(torch.nn.Module):
                     adapter=self,
                     adapter_hidden_size=self.token_size,
                     has_bias=False,
+                    block_idx=current_idx
                 )
             else:
                 attn_procs[name] = VisionDirectAdapterAttnProcessor(
@@ -560,7 +581,9 @@ class VisionDirectAdapter(torch.nn.Module):
                     adapter_hidden_size=self.token_size,
                     has_bias=False,
                 )
+            current_idx += 1
+
             attn_procs[name].load_state_dict(weights)

         if self.sd_ref().is_pixart:
             # we have to set them ourselves
             transformer: Transformer2DModel = sd.unet
@@ -595,14 +618,25 @@ class VisionDirectAdapter(torch.nn.Module):
         sd.unet.set_attn_processor(attn_procs)
         self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

         # # add the mlp layer
         # self.mlp = MLP(
         #     in_dim=self.token_size,
         #     out_dim=self.token_size,
         #     hidden_dim=self.token_size,
         #     # dropout=0.1,
         #     use_residual=True
         # )

+        num_modules = len(self.adapter_modules)
+        if self.config.train_scaler:
+            self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules).to(
+                dtype=torch.float32,
+                device=self.sd_ref().device_torch
+            ))
+            self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+            self.block_scaler.requires_grad = True
+        else:
+            self.block_scaler = None
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        if self.config.train_scaler:
+            # only return the block scaler
+            if destination is None:
+                destination = OrderedDict()
+            destination[prefix + 'block_scaler'] = self.block_scaler
+            return destination
+        return super().state_dict(destination, prefix, keep_vars)

     # make a getter to see if is active
     @property
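
The state_dict override means a checkpoint saved in train_scaler mode carries a single tensor rather than the full adapter. A standalone sketch of the same behavior (class name and module count are illustrative):

    from collections import OrderedDict
    import torch

    class ScalerOnly(torch.nn.Module):
        def __init__(self, num_modules: int):
            super().__init__()
            self.block_scaler = torch.nn.Parameter(torch.ones(num_modules, dtype=torch.float32))

        def state_dict(self, destination=None, prefix='', keep_vars=False):
            # persist only the scaler; the frozen adapter weights are omitted
            if destination is None:
                destination = OrderedDict()
            destination[prefix + 'block_scaler'] = self.block_scaler
            return destination

    assert list(ScalerOnly(4).state_dict().keys()) == ['block_scaler']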
@@ -610,5 +644,21 @@ class VisionDirectAdapter(torch.nn.Module):
         return self.adapter_ref().is_active

     def forward(self, input):
         # return self.mlp(input)
+        # block scaler keeps moving dtypes. make sure it is float32 here
+        # todo: remove this when we have a real solution
+        if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
+            self.block_scaler.data = self.block_scaler.data.to(torch.float32)
         return input
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        if self.block_scaler is not None:
+            if self.block_scaler.dtype != torch.float32:
+                self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+        return self
+
+    def post_weight_update(self):
+        # force the block scaler to have a mean of 1
+        if self.block_scaler is not None:
+            self.block_scaler.data = self.block_scaler.data / self.block_scaler.data.mean()
+        pass
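
Dividing by the mean after each step pins the average block scale at exactly 1 while preserving the relative ratios between blocks, so the scaler can redistribute influence across blocks without drifting in overall magnitude. A quick worked check:

    import torch

    scaler = torch.tensor([0.8, 1.2, 2.0])      # mean = 4/3
    normalized = scaler / scaler.mean()
    print(normalized)         # tensor([0.6000, 0.9000, 1.5000])
    print(normalized.mean())  # tensor(1.)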