Added support for training vision direct weight adapters

This commit is contained in:
Jaret Burkett
2024-09-05 10:11:44 -06:00
parent 5c8fcc8a4e
commit 3a1f464132
3 changed files with 78 additions and 16 deletions

View File

@@ -1584,6 +1584,8 @@ class SDTrainer(BaseSDTrainProcess):
                 self.scaler.update()
                 self.optimizer.zero_grad(set_to_none=True)
+                if self.adapter and isinstance(self.adapter, CustomAdapter):
+                    self.adapter.post_weight_update()
                 if self.ema is not None:
                     with self.timer('ema_update'):
                         self.ema.update()
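
Editor's note (not part of the diff): the new hook fires only after scaler.update() and optimizer.zero_grad(), so the adapter post-processes the final weights of the step; the isinstance check skips adapter types that do not implement post_weight_update(). A minimal runnable sketch of that ordering, using hypothetical stand-in objects rather than the trainer's real self.optimizer / self.adapter:

    import torch

    # hypothetical stand-ins for the trainer's optimizer and adapter
    weights = torch.nn.Parameter(torch.ones(4))
    optimizer = torch.optim.SGD([weights], lr=0.1)

    class TinyAdapter:
        def post_weight_update(self):
            # runs once per step, after the optimizer has produced the new weights
            with torch.no_grad():
                weights.clamp_(min=0.0)

    adapter = TinyAdapter()

    loss = (weights - 0.5).pow(2).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    if adapter is not None and isinstance(adapter, TinyAdapter):
        adapter.post_weight_update()  # same placement as the new call in the diff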

View File

@@ -402,7 +402,7 @@ class CustomAdapter(torch.nn.Module):
         if 'sv_adapter' in state_dict:
             self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
-        if 'vision_encoder' in state_dict and self.config.train_image_encoder:
+        if 'vision_encoder' in state_dict:
             self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
         if 'fuse_module' in state_dict:
@@ -881,10 +881,14 @@ class CustomAdapter(torch.nn.Module):
             for attn_processor in self.te_adapter.adapter_modules:
                 yield from attn_processor.parameters(recurse)
         elif self.config.type == 'vision_direct':
-            for attn_processor in self.vd_adapter.adapter_modules:
-                yield from attn_processor.parameters(recurse)
-            if self.config.train_image_encoder:
-                yield from self.vision_encoder.parameters(recurse)
+            if self.config.train_scaler:
+                # only yield the block scaler
+                yield self.vd_adapter.block_scaler
+            else:
+                for attn_processor in self.vd_adapter.adapter_modules:
+                    yield from attn_processor.parameters(recurse)
+                if self.config.train_image_encoder:
+                    yield from self.vision_encoder.parameters(recurse)
         elif self.config.type == 'te_augmenter':
             yield from self.te_augmenter.parameters(recurse)
             if self.config.train_image_encoder:
@@ -908,4 +912,10 @@ class CustomAdapter(torch.nn.Module):
                 additional[k] = v
         additional['clip_layer'] = self.config.clip_layer
         additional['image_encoder_arch'] = self.config.head_dim
         return additional
+
+    def post_weight_update(self):
+        # do any kind of updates after the weight update
+        if self.config.type == 'vision_direct':
+            self.vd_adapter.post_weight_update()
+        pass
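
Editor's note (not part of the diff): with type='vision_direct' and train_scaler enabled, the parameters() generator above hands the optimizer exactly one tensor, the per-block scaler. A self-contained toy (not toolkit code) of that routing pattern:

    import torch

    # Toy adapter mirroring the routing above: with a "train_scaler"-style flag
    # set, parameters() yields only the per-block scaler, so an optimizer built
    # from it cannot touch the attention-processor weights.
    class ToyAdapter(torch.nn.Module):
        def __init__(self, num_blocks: int, train_scaler: bool):
            super().__init__()
            self.train_scaler = train_scaler
            self.block_scaler = torch.nn.Parameter(torch.ones(num_blocks, dtype=torch.float32))
            self.adapter_modules = torch.nn.ModuleList(
                [torch.nn.Linear(8, 8) for _ in range(num_blocks)]
            )

        def parameters(self, recurse: bool = True):
            if self.train_scaler:
                # only yield the block scaler
                yield self.block_scaler
            else:
                for m in self.adapter_modules:
                    yield from m.parameters(recurse)

    adapter = ToyAdapter(num_blocks=4, train_scaler=True)
    optimizer = torch.optim.AdamW(adapter.parameters(), lr=1e-3)
    print(sum(p.numel() for p in optimizer.param_groups[0]['params']))  # 4 -> just the scaler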

View File

@@ -5,9 +5,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 import weakref
 from typing import Union, TYPE_CHECKING, Optional
+from collections import OrderedDict
 from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
+from toolkit.config_modules import AdapterConfig
 from toolkit.paths import REPOS_ROOT

 sys.path.append(REPOS_ROOT)
@@ -286,7 +289,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
"""Attention processor used typically in processing the SD3-like self-attention projections.""" """Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None, def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
adapter_hidden_size=None, has_bias=False, **kwargs): adapter_hidden_size=None, has_bias=False, block_idx=0, **kwargs):
super().__init__() super().__init__()
if not hasattr(F, "scaled_dot_product_attention"): if not hasattr(F, "scaled_dot_product_attention"):
@@ -298,6 +301,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         self.adapter_hidden_size = adapter_hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.scale = scale
+        self.block_idx = block_idx

         self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
         self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
@@ -382,6 +386,10 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         # begin ip adapter
         if self.is_active and self.conditional_embeds is not None:
             adapter_hidden_states = self.conditional_embeds
+            block_scaler = self.adapter_ref().block_scaler
+            if block_scaler is not None:
+                block_scaler = block_scaler[self.block_idx]
             if adapter_hidden_states.shape[0] < batch_size:
                 adapter_hidden_states = torch.cat([
                     self.unconditional_embeds,
@@ -407,6 +415,15 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
             vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
             vd_hidden_states = vd_hidden_states.to(query.dtype)

+            # scale to block scaler
+            if block_scaler is not None:
+                orig_dtype = vd_hidden_states.dtype
+                if block_scaler.dtype != vd_hidden_states.dtype:
+                    vd_hidden_states = vd_hidden_states.to(block_scaler.dtype)
+                vd_hidden_states = vd_hidden_states * block_scaler
+                if block_scaler.dtype != orig_dtype:
+                    vd_hidden_states = vd_hidden_states.to(orig_dtype)
+
             hidden_states = hidden_states + self.scale * vd_hidden_states

         if encoder_hidden_states is not None:
@@ -437,6 +454,7 @@ class VisionDirectAdapter(torch.nn.Module):
         is_flux = sd.is_flux
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.sd_ref: weakref.ref = weakref.ref(sd)
+        self.config: AdapterConfig = adapter.config
         self.vision_model_ref: weakref.ref = weakref.ref(vision_model)

         if adapter.config.clip_layer == "image_embeds":
@@ -469,6 +487,8 @@ class VisionDirectAdapter(torch.nn.Module):
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())

+        current_idx = 0
+
         for name in attn_processor_keys:
             if is_flux:
                 cross_attention_dim = None
@@ -550,6 +570,7 @@ class VisionDirectAdapter(torch.nn.Module):
                     adapter=self,
                     adapter_hidden_size=self.token_size,
                     has_bias=False,
+                    block_idx=current_idx
                 )
             else:
                 attn_procs[name] = VisionDirectAdapterAttnProcessor(
@@ -560,7 +581,9 @@ class VisionDirectAdapter(torch.nn.Module):
                     adapter_hidden_size=self.token_size,
                     has_bias=False,
                 )
+            current_idx += 1
+
             attn_procs[name].load_state_dict(weights)

         if self.sd_ref().is_pixart:
             # we have to set them ourselves
             transformer: Transformer2DModel = sd.unet
@@ -595,14 +618,25 @@ class VisionDirectAdapter(torch.nn.Module):
         sd.unet.set_attn_processor(attn_procs)
         self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

-        # # add the mlp layer
-        # self.mlp = MLP(
-        #     in_dim=self.token_size,
-        #     out_dim=self.token_size,
-        #     hidden_dim=self.token_size,
-        #     # dropout=0.1,
-        #     use_residual=True
-        # )
+        num_modules = len(self.adapter_modules)
+        if self.config.train_scaler:
+            self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules).to(
+                dtype=torch.float32,
+                device=self.sd_ref().device_torch
+            ))
+            self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+            self.block_scaler.requires_grad = True
+        else:
+            self.block_scaler = None
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        if self.config.train_scaler:
+            # only return the block scaler
+            if destination is None:
+                destination = OrderedDict()
+            destination[prefix + 'block_scaler'] = self.block_scaler
+            return destination
+        return super().state_dict(destination, prefix, keep_vars)

     # make a getter to see if is active
     @property
@@ -610,5 +644,21 @@ class VisionDirectAdapter(torch.nn.Module):
         return self.adapter_ref().is_active

     def forward(self, input):
-        # return self.mlp(input)
+        # block scaler keeps moving dtypes. make sure it is float32 here
+        # todo remove this when we have a real solution
+        if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
+            self.block_scaler.data = self.block_scaler.data.to(torch.float32)
         return input
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        if self.block_scaler is not None:
+            if self.block_scaler.dtype != torch.float32:
+                self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+        return self
+
+    def post_weight_update(self):
+        # force block scaler to be mean of 1
+        if self.block_scaler is not None:
+            self.block_scaler.data = self.block_scaler.data / self.block_scaler.data.mean()
+        pass
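
Editor's note (not part of the diff): the forward() and to() overrides keep block_scaler in float32 because, per the inline comments, other parts of the training loop keep moving its dtype, and post_weight_update() renormalizes the vector to a mean of 1.0, so training redistributes influence between blocks rather than inflating the average scale. A small worked example of that renormalization:

    import torch

    # Illustration only of the mean-1 renormalization in post_weight_update():
    # per-block ratios are preserved while the mean returns to 1.0.
    block_scaler = torch.tensor([0.5, 1.0, 1.5, 3.0], dtype=torch.float32)  # mean = 1.5
    block_scaler = block_scaler / block_scaler.mean()
    print(block_scaler)         # tensor([0.3333, 0.6667, 1.0000, 2.0000])
    print(block_scaler.mean())  # tensor(1.)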