Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Added support for training vision direct weight adapters
@@ -1584,6 +1584,8 @@ class SDTrainer(BaseSDTrainProcess):
             self.scaler.update()

             self.optimizer.zero_grad(set_to_none=True)
+        if self.adapter and isinstance(self.adapter, CustomAdapter):
+            self.adapter.post_weight_update()
         if self.ema is not None:
             with self.timer('ema_update'):
                 self.ema.update()
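The hook above runs once per optimizer step, right after the weights change. A minimal sketch of that call order, assuming nothing about the trainer beyond what the hunk shows; StubAdapter and the dummy loss are stand-ins, not code from the repo:

    import torch

    class StubAdapter(torch.nn.Module):
        """Stand-in for CustomAdapter: holds per-block gains and renormalizes them after each step."""
        def __init__(self, num_blocks: int):
            super().__init__()
            self.block_scaler = torch.nn.Parameter(torch.ones(num_blocks))

        def post_weight_update(self):
            # keep the mean gain at 1.0 so overall adapter strength stays constant
            self.block_scaler.data = self.block_scaler.data / self.block_scaler.data.mean()

    adapter = StubAdapter(num_blocks=4)
    optimizer = torch.optim.AdamW(adapter.parameters(), lr=1e-2)

    for _ in range(3):
        loss = (adapter.block_scaler ** 2).sum()   # dummy loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        adapter.post_weight_update()               # runs right after the weight update, as in the trainer

    print(adapter.block_scaler.data.mean())        # stays at 1.0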
@@ -402,7 +402,7 @@ class CustomAdapter(torch.nn.Module):
         if 'sv_adapter' in state_dict:
             self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)

-        if 'vision_encoder' in state_dict and self.config.train_image_encoder:
+        if 'vision_encoder' in state_dict:
             self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)

         if 'fuse_module' in state_dict:
@@ -881,10 +881,14 @@ class CustomAdapter(torch.nn.Module):
             for attn_processor in self.te_adapter.adapter_modules:
                 yield from attn_processor.parameters(recurse)
         elif self.config.type == 'vision_direct':
-            for attn_processor in self.vd_adapter.adapter_modules:
-                yield from attn_processor.parameters(recurse)
-            if self.config.train_image_encoder:
-                yield from self.vision_encoder.parameters(recurse)
+            if self.config.train_scaler:
+                # only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)
+                yield self.vd_adapter.block_scaler
+            else:
+                for attn_processor in self.vd_adapter.adapter_modules:
+                    yield from attn_processor.parameters(recurse)
+                if self.config.train_image_encoder:
+                    yield from self.vision_encoder.parameters(recurse)
         elif self.config.type == 'te_augmenter':
             yield from self.te_augmenter.parameters(recurse)
             if self.config.train_image_encoder:
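With train_scaler set, the generator above hands the optimizer only the per-block scaler and leaves the attention projections frozen. A self-contained sketch of that selection, using an illustrative ToyVDAdapter class rather than the repo's CustomAdapter/VisionDirectAdapter:

    import torch
    import torch.nn as nn

    class ToyVDAdapter(nn.Module):
        def __init__(self, num_modules: int, hidden: int, train_scaler: bool):
            super().__init__()
            self.train_scaler = train_scaler
            self.adapter_modules = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_modules))
            self.block_scaler = nn.Parameter(torch.ones(num_modules)) if train_scaler else None

        def trainable_parameters(self, recurse: bool = True):
            if self.train_scaler:
                # only the per-block gains are trained; the k/v projections stay frozen
                yield self.block_scaler
            else:
                for module in self.adapter_modules:
                    yield from module.parameters(recurse)

    adapter = ToyVDAdapter(num_modules=3, hidden=8, train_scaler=True)
    optimizer = torch.optim.AdamW(adapter.trainable_parameters(), lr=1e-4)
    print(sum(p.numel() for group in optimizer.param_groups for p in group['params']))  # 3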
@@ -908,4 +912,10 @@ class CustomAdapter(torch.nn.Module):
                 additional[k] = v
         additional['clip_layer'] = self.config.clip_layer
         additional['image_encoder_arch'] = self.config.head_dim
-        return additional
+        return additional
+
+    def post_weight_update(self):
+        # do any kind of updates after the weight update
+        if self.config.type == 'vision_direct':
+            self.vd_adapter.post_weight_update()
+        pass
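The extra entries collected here (clip_layer, image_encoder_arch) are metadata meant to travel with the saved adapter. How the repo actually persists them is not shown in this diff; as one possible round trip, string metadata can ride along in a safetensors header. The file name, tensor, and values below are made up for the sketch:

    import torch
    from safetensors.torch import save_file
    from safetensors import safe_open

    # illustrative adapter state and metadata; keys and values are invented for this example
    tensors = {'block_scaler': torch.ones(4)}
    metadata = {'clip_layer': 'image_embeds', 'image_encoder_arch': 'clip'}  # values must be str

    save_file(tensors, 'vd_adapter.safetensors', metadata=metadata)

    with safe_open('vd_adapter.safetensors', framework='pt') as f:
        print(f.metadata())                  # {'clip_layer': 'image_embeds', 'image_encoder_arch': 'clip'}
        print(f.get_tensor('block_scaler'))  # tensor([1., 1., 1., 1.])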
@@ -5,9 +5,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 import weakref
 from typing import Union, TYPE_CHECKING, Optional
 from collections import OrderedDict

 from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection

 from toolkit.config_modules import AdapterConfig
 from toolkit.paths import REPOS_ROOT
 sys.path.append(REPOS_ROOT)
@@ -286,7 +289,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
     """Attention processor used typically in processing the SD3-like self-attention projections."""

     def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
-                 adapter_hidden_size=None, has_bias=False, **kwargs):
+                 adapter_hidden_size=None, has_bias=False, block_idx=0, **kwargs):
         super().__init__()

         if not hasattr(F, "scaled_dot_product_attention"):
@@ -298,6 +301,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         self.adapter_hidden_size = adapter_hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.scale = scale
+        self.block_idx = block_idx

         self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
         self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
@@ -382,6 +386,10 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         # begin ip adapter
         if self.is_active and self.conditional_embeds is not None:
             adapter_hidden_states = self.conditional_embeds
+            block_scaler = self.adapter_ref().block_scaler
+            if block_scaler is not None:
+                block_scaler = block_scaler[self.block_idx]
+
             if adapter_hidden_states.shape[0] < batch_size:
                 adapter_hidden_states = torch.cat([
                     self.unconditional_embeds,
@@ -407,6 +415,15 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
             vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
             vd_hidden_states = vd_hidden_states.to(query.dtype)

+            # scale to block scaler
+            if block_scaler is not None:
+                orig_dtype = vd_hidden_states.dtype
+                if block_scaler.dtype != vd_hidden_states.dtype:
+                    vd_hidden_states = vd_hidden_states.to(block_scaler.dtype)
+                vd_hidden_states = vd_hidden_states * block_scaler
+                if block_scaler.dtype != orig_dtype:
+                    vd_hidden_states = vd_hidden_states.to(orig_dtype)
+
             hidden_states = hidden_states + self.scale * vd_hidden_states

         if encoder_hidden_states is not None:
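block_scaler lives in float32 while the attention activations are typically fp16/bf16, so the branch above casts up, multiplies by the current block's gain, and casts back. The same step in isolation; the function name and shapes are ours, not the repo's:

    import torch

    def scale_by_block(hidden_states: torch.Tensor, block_scaler: torch.Tensor, block_idx: int) -> torch.Tensor:
        """Multiply one block's adapter output by its learned gain, matching dtypes temporarily."""
        scaler = block_scaler[block_idx]
        orig_dtype = hidden_states.dtype
        if scaler.dtype != orig_dtype:
            hidden_states = hidden_states.to(scaler.dtype)   # do the scaling in float32
        hidden_states = hidden_states * scaler
        return hidden_states.to(orig_dtype)

    block_scaler = torch.nn.Parameter(torch.ones(8, dtype=torch.float32))
    vd_hidden_states = torch.randn(2, 77, 64, dtype=torch.bfloat16)
    out = scale_by_block(vd_hidden_states, block_scaler, block_idx=5)
    print(out.dtype)  # torch.bfloat16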
@@ -437,6 +454,7 @@ class VisionDirectAdapter(torch.nn.Module):
         is_flux = sd.is_flux
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.sd_ref: weakref.ref = weakref.ref(sd)
+        self.config: AdapterConfig = adapter.config
         self.vision_model_ref: weakref.ref = weakref.ref(vision_model)

         if adapter.config.clip_layer == "image_embeds":
@@ -469,6 +487,8 @@ class VisionDirectAdapter(torch.nn.Module):
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())

+        current_idx = 0
+
         for name in attn_processor_keys:
             if is_flux:
                 cross_attention_dim = None
@@ -550,6 +570,7 @@ class VisionDirectAdapter(torch.nn.Module):
                     adapter=self,
                     adapter_hidden_size=self.token_size,
                     has_bias=False,
+                    block_idx=current_idx
                 )
             else:
                 attn_procs[name] = VisionDirectAdapterAttnProcessor(
@@ -560,7 +581,9 @@ class VisionDirectAdapter(torch.nn.Module):
                     adapter_hidden_size=self.token_size,
                     has_bias=False,
                 )
+            current_idx += 1
             attn_procs[name].load_state_dict(weights)

         if self.sd_ref().is_pixart:
             # we have to set them ourselves
             transformer: Transformer2DModel = sd.unet
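Each processor receives its position as block_idx, and current_idx advances once per attention module so the index matches one entry of block_scaler. A stripped-down sketch of that wiring; ToyProcessor and the key names are invented for illustration:

    import torch
    import torch.nn as nn

    class ToyProcessor(nn.Module):
        """Stand-in attention processor that only remembers which block it belongs to."""
        def __init__(self, hidden_size: int, block_idx: int):
            super().__init__()
            self.block_idx = block_idx
            self.to_k_adapter = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_v_adapter = nn.Linear(hidden_size, hidden_size, bias=False)

    attn_processor_keys = ['down.attn1', 'mid.attn1', 'up.attn1']  # made-up names
    attn_procs = {}
    current_idx = 0
    for name in attn_processor_keys:
        attn_procs[name] = ToyProcessor(hidden_size=16, block_idx=current_idx)
        current_idx += 1

    block_scaler = nn.Parameter(torch.ones(len(attn_procs)))
    for name, proc in attn_procs.items():
        print(name, '->', block_scaler[proc.block_idx].item())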
@@ -595,14 +618,25 @@ class VisionDirectAdapter(torch.nn.Module):
         sd.unet.set_attn_processor(attn_procs)
         self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

-        # # add the mlp layer
-        # self.mlp = MLP(
-        #     in_dim=self.token_size,
-        #     out_dim=self.token_size,
-        #     hidden_dim=self.token_size,
-        #     # dropout=0.1,
-        #     use_residual=True
-        # )
+        num_modules = len(self.adapter_modules)
+        if self.config.train_scaler:
+            self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules).to(
+                dtype=torch.float32,
+                device=self.sd_ref().device_torch
+            ))
+            self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+            self.block_scaler.requires_grad = True
+        else:
+            self.block_scaler = None
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        if self.config.train_scaler:
+            # only return the block scaler
+            if destination is None:
+                destination = OrderedDict()
+            destination[prefix + 'block_scaler'] = self.block_scaler
+            return destination
+        return super().state_dict(destination, prefix, keep_vars)

     # make a getter to see if is active
     @property
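When only the scaler is trained, the state_dict override above returns a single block_scaler entry, so checkpoints carry just the per-block gains. A standalone version of the same pattern with an illustrative ScalerOnly module:

    from collections import OrderedDict
    import torch
    import torch.nn as nn

    class ScalerOnly(nn.Module):
        def __init__(self, num_modules: int, train_scaler: bool):
            super().__init__()
            self.train_scaler = train_scaler
            self.proj = nn.Linear(8, 8)  # stands in for the adapter's k/v projections
            self.block_scaler = nn.Parameter(torch.ones(num_modules))

        def state_dict(self, destination=None, prefix='', keep_vars=False):
            if self.train_scaler:
                # checkpoint only the per-block gains
                if destination is None:
                    destination = OrderedDict()
                destination[prefix + 'block_scaler'] = self.block_scaler
                return destination
            return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

    print(list(ScalerOnly(4, train_scaler=True).state_dict().keys()))   # ['block_scaler']
    print(list(ScalerOnly(4, train_scaler=False).state_dict().keys()))  # ['block_scaler', 'proj.weight', 'proj.bias']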
@@ -610,5 +644,21 @@ class VisionDirectAdapter(torch.nn.Module):
         return self.adapter_ref().is_active

     def forward(self, input):
         # return self.mlp(input)
+        # block scaler keeps moving dtypes. make sure it is float32 here
+        # todo remove this when we have a real solution
+        if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
+            self.block_scaler.data = self.block_scaler.data.to(torch.float32)
         return input
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        if self.block_scaler is not None:
+            if self.block_scaler.dtype != torch.float32:
+                self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+        return self
+
+    def post_weight_update(self):
+        # force block scaler to be mean of 1
+        if self.block_scaler is not None:
+            self.block_scaler.data = self.block_scaler.data / self.block_scaler.data.mean()
+        pass
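post_weight_update divides the gains by their own mean after every step, so relative per-block weighting can drift while the average stays pinned at 1.0. The arithmetic in isolation, with made-up values:

    import torch

    block_scaler = torch.nn.Parameter(torch.tensor([1.0, 1.0, 1.0, 1.0]))

    # pretend an optimizer step nudged the gains
    block_scaler.data = torch.tensor([0.8, 1.3, 1.1, 0.6])

    # force block scaler to be mean of 1, as in VisionDirectAdapter.post_weight_update
    block_scaler.data = block_scaler.data / block_scaler.data.mean()

    print(block_scaler.data)         # roughly [0.8421, 1.3684, 1.1579, 0.6316]
    print(block_scaler.data.mean())  # back at 1.0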