Added support for vision direct adapter for flux

This commit is contained in:
Jaret Burkett
2024-08-26 16:27:28 -06:00
parent e127c079da
commit 3843e0d148
3 changed files with 268 additions and 41 deletions

View File

@@ -62,6 +62,13 @@ class SDTrainer(BaseSDTrainProcess):
self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
self.do_grad_scale = True
if self.is_fine_tuning:
self.do_grad_scale = False
if self.adapter_config is not None:
if self.adapter_config.train:
self.do_grad_scale = False
if self.train_config.dtype in ["fp16", "float16"]:
# patch the scaler to allow fp16 training
org_unscale_grads = self.scaler._unscale_grads_
@@ -1519,7 +1526,7 @@ class SDTrainer(BaseSDTrainProcess):
# if self.is_bfloat:
# loss.backward()
# else:
if self.is_fine_tuning:
if not self.do_grad_scale:
loss.backward()
else:
self.scaler.scale(loss).backward()
@@ -1528,7 +1535,7 @@ class SDTrainer(BaseSDTrainProcess):
if not self.is_grad_accumulation_step:
# fix this for multi params
if self.train_config.optimizer != 'adafactor':
if not self.is_fine_tuning:
if self.do_grad_scale:
self.scaler.unscale_(self.optimizer)
if isinstance(self.params[0], dict):
for i in range(len(self.params)):
@@ -1538,7 +1545,7 @@ class SDTrainer(BaseSDTrainProcess):
# only step if we are not accumulating
with self.timer('optimizer_step'):
# self.optimizer.step()
if self.is_fine_tuning:
if not self.do_grad_scale:
self.optimizer.step()
else:
self.scaler.step(self.optimizer)

View File

@@ -852,6 +852,9 @@ class CustomAdapter(torch.nn.Module):
if self.adapter_type == 'te_augmenter':
clip_image_embeds = self.te_augmenter(clip_image_embeds)
if self.adapter_type == 'vision_direct':
clip_image_embeds = self.vd_adapter(clip_image_embeds)
# save them to the conditional and unconditional
try:
self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)

View File

@@ -4,9 +4,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import weakref
from typing import Union, TYPE_CHECKING
from typing import Union, TYPE_CHECKING, Optional
from diffusers import Transformer2DModel
from diffusers import Transformer2DModel, FluxTransformer2DModel
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
@@ -16,6 +16,30 @@ if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.custom_adapter import CustomAdapter
class MLP(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
super().__init__()
if use_residual:
assert in_dim == out_dim
self.layernorm = nn.LayerNorm(in_dim)
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, out_dim)
self.dropout = nn.Dropout(dropout)
self.use_residual = use_residual
self.act_fn = nn.GELU()
def forward(self, x):
residual = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x = self.dropout(x)
if self.use_residual:
x = x + residual
return x
class AttnProcessor2_0(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -258,6 +282,168 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
return hidden_states
class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
adapter_hidden_size=None, has_bias=False, **kwargs):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.hidden_size = hidden_size
self.adapter_hidden_size = adapter_hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
@property
def is_active(self):
return self.adapter_ref().is_active
# return False
@property
def unconditional_embeds(self):
return self.adapter_ref().adapter_ref().unconditional_embeds
@property
def conditional_embeds(self):
return self.adapter_ref().adapter_ref().conditional_embeds
def __call__(
self,
attn,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
is_active = self.adapter_ref().is_active
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size = encoder_hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
# YiYi to-do: update uising apply_rotary_emb
# from ..embeddings import apply_rotary_emb
# query = apply_rotary_emb(query, image_rotary_emb)
# key = apply_rotary_emb(key, image_rotary_emb)
from diffusers.models.embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# do ip adapter
# will be none if disabled
if self.is_active and self.conditional_embeds is not None:
adapter_hidden_states = self.conditional_embeds
if adapter_hidden_states.shape[0] < batch_size:
adapter_hidden_states = torch.cat([
self.unconditional_embeds,
adapter_hidden_states
], dim=0)
# if it is image embeds, we need to add a 1 dim at inx 1
if len(adapter_hidden_states.shape) == 2:
adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
# conditional_batch_size = adapter_hidden_states.shape[0]
# conditional_query = query
# for ip-adapter
vd_key = self.to_k_adapter(adapter_hidden_states)
vd_value = self.to_v_adapter(adapter_hidden_states)
vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
vd_hidden_states = F.scaled_dot_product_attention(
query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
vd_hidden_states = vd_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * vd_hidden_states
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
return hidden_states, encoder_hidden_states
class VisionDirectAdapter(torch.nn.Module):
def __init__(
self,
@@ -267,6 +453,7 @@ class VisionDirectAdapter(torch.nn.Module):
):
super(VisionDirectAdapter, self).__init__()
is_pixart = sd.is_pixart
is_flux = sd.is_flux
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.sd_ref: weakref.ref = weakref.ref(sd)
self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
@@ -290,11 +477,22 @@ class VisionDirectAdapter(torch.nn.Module):
# cross attention
attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
elif is_flux:
transformer: FluxTransformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
attn_processor_keys.append(f"transformer_blocks.{i}.attn")
# single transformer blocks do not have cross attn
# for i, module in transformer.single_transformer_blocks.named_children():
# attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
else:
attn_processor_keys = list(sd.unet.attn_processors.keys())
for name in attn_processor_keys:
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim']
if is_flux:
cross_attention_dim = None
else:
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim']
if name.startswith("mid_block"):
hidden_size = sd.unet.config['block_out_channels'][-1]
elif name.startswith("up_blocks"):
@@ -304,22 +502,27 @@ class VisionDirectAdapter(torch.nn.Module):
block_id = int(name[len("down_blocks.")])
hidden_size = sd.unet.config['block_out_channels'][block_id]
elif name.startswith("transformer"):
hidden_size = sd.unet.config['cross_attention_dim']
if is_flux:
hidden_size = 3072
else:
hidden_size = sd.unet.config['cross_attention_dim']
else:
# they didnt have this, but would lead to undefined below
raise ValueError(f"unknown attn processor name: {name}")
if cross_attention_dim is None:
if cross_attention_dim is None and not is_flux:
attn_procs[name] = AttnProcessor2_0()
else:
layer_name = name.split(".processor")[0]
to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
# if is_pixart:
# to_k_bias = unet_sd[layer_name + ".to_k.bias"]
# to_v_bias = unet_sd[layer_name + ".to_v.bias"]
# else:
# to_k_bias = None
# to_v_bias = None
if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux:
# is quantized
to_k_adapter = torch.randn(hidden_size, hidden_size) * 0.01
to_v_adapter = torch.randn(hidden_size, hidden_size) * 0.01
to_k_adapter = to_k_adapter.to(self.sd_ref().torch_dtype)
to_v_adapter = to_v_adapter.to(self.sd_ref().torch_dtype)
else:
to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
# add zero padding to the adapter
if to_k_adapter.shape[1] < self.token_size:
@@ -337,21 +540,6 @@ class VisionDirectAdapter(torch.nn.Module):
],
dim=1
)
# if is_pixart:
# to_k_bias = torch.cat([
# to_k_bias,
# torch.zeros(self.token_size - to_k_adapter.shape[1]).to(
# to_k_adapter.device, dtype=to_k_adapter.dtype)
# ],
# dim=0
# )
# to_v_bias = torch.cat([
# to_v_bias,
# torch.zeros(self.token_size - to_v_adapter.shape[1]).to(
# to_k_adapter.device, dtype=to_k_adapter.dtype)
# ],
# dim=0
# )
elif to_k_adapter.shape[1] > self.token_size:
to_k_adapter = to_k_adapter[:, :self.token_size]
to_v_adapter = to_v_adapter[:, :self.token_size]
@@ -371,16 +559,26 @@ class VisionDirectAdapter(torch.nn.Module):
}
# if is_pixart:
# weights["to_k_adapter.bias"] = to_k_bias
# weights["to_v_adapter.bias"] = to_v_bias
# weights["to_v_adapter.bias"] = to_v_bias\
attn_procs[name] = VisionDirectAdapterAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
adapter=self,
adapter_hidden_size=self.token_size,
has_bias=False,
)
if is_flux:
attn_procs[name] = CustomFluxVDAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
adapter=self,
adapter_hidden_size=self.token_size,
has_bias=False,
)
else:
attn_procs[name] = VisionDirectAdapterAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
adapter=self,
adapter_hidden_size=self.token_size,
has_bias=False,
)
attn_procs[name].load_state_dict(weights)
if self.sd_ref().is_pixart:
# we have to set them ourselves
@@ -393,14 +591,33 @@ class VisionDirectAdapter(torch.nn.Module):
] + [
transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks))
])
elif self.sd_ref().is_flux:
# we have to set them ourselves
transformer: FluxTransformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
self.adapter_modules = torch.nn.ModuleList(
[
transformer.transformer_blocks[i].attn.processor for i in
range(len(transformer.transformer_blocks))
])
else:
sd.unet.set_attn_processor(attn_procs)
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
# add the mlp layer
self.mlp = MLP(
in_dim=self.token_size,
out_dim=self.token_size,
hidden_dim=self.token_size,
# dropout=0.1,
use_residual=True
)
# make a getter to see if is active
@property
def is_active(self):
return self.adapter_ref().is_active
def forward(self, input):
return input
return self.mlp(input)