mirror of
https://github.com/ostris/ai-toolkit.git
Added support for vision direct adapter for flux
@@ -62,6 +62,13 @@ class SDTrainer(BaseSDTrainProcess):
         self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16"
 
+        self.do_grad_scale = True
+        if self.is_fine_tuning:
+            self.do_grad_scale = False
+        if self.adapter_config is not None:
+            if self.adapter_config.train:
+                self.do_grad_scale = False
+
         if self.train_config.dtype in ["fp16", "float16"]:
             # patch the scaler to allow fp16 training
             org_unscale_grads = self.scaler._unscale_grads_
@@ -1519,7 +1526,7 @@ class SDTrainer(BaseSDTrainProcess):
             # if self.is_bfloat:
             #     loss.backward()
             # else:
-            if self.is_fine_tuning:
+            if not self.do_grad_scale:
                 loss.backward()
             else:
                 self.scaler.scale(loss).backward()
@@ -1528,7 +1535,7 @@ class SDTrainer(BaseSDTrainProcess):
         if not self.is_grad_accumulation_step:
             # fix this for multi params
             if self.train_config.optimizer != 'adafactor':
-                if not self.is_fine_tuning:
+                if self.do_grad_scale:
                     self.scaler.unscale_(self.optimizer)
                 if isinstance(self.params[0], dict):
                     for i in range(len(self.params)):
@@ -1538,7 +1545,7 @@ class SDTrainer(BaseSDTrainProcess):
             # only step if we are not accumulating
             with self.timer('optimizer_step'):
                 # self.optimizer.step()
-                if self.is_fine_tuning:
+                if not self.do_grad_scale:
                     self.optimizer.step()
                 else:
                     self.scaler.step(self.optimizer)
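For context on the trainer hunks above: the new do_grad_scale flag only decides whether the loss goes through the fp16 GradScaler or is backpropagated directly. Below is a minimal sketch, not this trainer's API; the names are illustrative and a CUDA device is assumed for the fp16 path.

import torch

# scaler = torch.cuda.amp.GradScaler()   # created once, outside the training loop

def backward_and_step(loss, optimizer, scaler, do_grad_scale):
    # Sketch only: fp16 runs use the scaler, bf16/fine-tuning runs do not.
    if do_grad_scale:
        scaler.scale(loss).backward()   # scale the loss to avoid fp16 gradient underflow
        scaler.unscale_(optimizer)      # unscale before any gradient clipping
        scaler.step(optimizer)          # skips the step if inf/NaN gradients were found
        scaler.update()                 # adapt the scale factor for the next step
    else:
        loss.backward()                 # bf16 or full fine-tuning: plain backward
        optimizer.step()
    optimizer.zero_grad(set_to_none=True)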
@@ -852,6 +852,9 @@ class CustomAdapter(torch.nn.Module):
             if self.adapter_type == 'te_augmenter':
                 clip_image_embeds = self.te_augmenter(clip_image_embeds)
 
+            if self.adapter_type == 'vision_direct':
+                clip_image_embeds = self.vd_adapter(clip_image_embeds)
+
             # save them to the conditional and unconditional
             try:
                 self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
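The vision_direct branch above runs the CLIP image embeddings through the new adapter module and then splits the batch the same way the other adapter types do. A hedged illustration of that chunk(2, dim=0) split, with made-up shapes, assuming the unconditional embeddings are stacked in front of the conditional ones:

import torch

# hypothetical (2 * batch, tokens, dim) CLIP image embeddings
clip_image_embeds = torch.randn(4, 1, 768)
unconditional_embeds, conditional_embeds = clip_image_embeds.chunk(2, dim=0)
print(unconditional_embeds.shape, conditional_embeds.shape)  # torch.Size([2, 1, 768]) twice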
@@ -4,9 +4,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import weakref
-from typing import Union, TYPE_CHECKING
+from typing import Union, TYPE_CHECKING, Optional
 
-from diffusers import Transformer2DModel
+from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
 from toolkit.paths import REPOS_ROOT
 sys.path.append(REPOS_ROOT)
@@ -16,6 +16,30 @@ if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
     from toolkit.custom_adapter import CustomAdapter
 
+
+class MLP(nn.Module):
+    def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
+        super().__init__()
+        if use_residual:
+            assert in_dim == out_dim
+        self.layernorm = nn.LayerNorm(in_dim)
+        self.fc1 = nn.Linear(in_dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, out_dim)
+        self.dropout = nn.Dropout(dropout)
+        self.use_residual = use_residual
+        self.act_fn = nn.GELU()
+
+    def forward(self, x):
+        residual = x
+        x = self.layernorm(x)
+        x = self.fc1(x)
+        x = self.act_fn(x)
+        x = self.fc2(x)
+        x = self.dropout(x)
+        if self.use_residual:
+            x = x + residual
+        return x
+
 class AttnProcessor2_0(torch.nn.Module):
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -258,6 +282,168 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
         return hidden_states
 
 
+class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
+                 adapter_hidden_size=None, has_bias=False, **kwargs):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+
+        self.hidden_size = hidden_size
+        self.adapter_hidden_size = adapter_hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+
+        self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+        self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
+
+    @property
+    def is_active(self):
+        return self.adapter_ref().is_active
+        # return False
+
+    @property
+    def unconditional_embeds(self):
+        return self.adapter_ref().adapter_ref().unconditional_embeds
+
+    @property
+    def conditional_embeds(self):
+        return self.adapter_ref().adapter_ref().conditional_embeds
+
+    def __call__(
+        self,
+        attn,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        is_active = self.adapter_ref().is_active
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+
+        if attn.norm_added_q is not None:
+            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+        if attn.norm_added_k is not None:
+            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+        # attention
+        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            # YiYi to-do: update using apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            from diffusers.models.embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # do ip adapter
+        # will be none if disabled
+        if self.is_active and self.conditional_embeds is not None:
+            adapter_hidden_states = self.conditional_embeds
+            if adapter_hidden_states.shape[0] < batch_size:
+                adapter_hidden_states = torch.cat([
+                    self.unconditional_embeds,
+                    adapter_hidden_states
+                ], dim=0)
+            # if it is image embeds, we need to add a 1 dim at index 1
+            if len(adapter_hidden_states.shape) == 2:
+                adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
+            # conditional_batch_size = adapter_hidden_states.shape[0]
+            # conditional_query = query
+
+            # for ip-adapter
+            vd_key = self.to_k_adapter(adapter_hidden_states)
+            vd_value = self.to_v_adapter(adapter_hidden_states)
+
+            vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            vd_hidden_states = F.scaled_dot_product_attention(
+                query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+
+            vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            vd_hidden_states = vd_hidden_states.to(query.dtype)
+
+            hidden_states = hidden_states + self.scale * vd_hidden_states
+
+
+        encoder_hidden_states, hidden_states = (
+            hidden_states[:, : encoder_hidden_states.shape[1]],
+            hidden_states[:, encoder_hidden_states.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
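CustomFluxVDAttnProcessor2_0 follows the decoupled, IP-Adapter style pattern: the joint text/image attention is computed exactly as stock Flux does, then the same queries attend to keys and values projected from the adapter's vision embeddings, and that result is added back with a scale. A self-contained sketch of just the extra branch, with hypothetical sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, heads, head_dim, seq_len, adapter_tokens = 2, 24, 128, 512, 1
hidden_size = heads * head_dim        # 3072 for Flux double blocks
adapter_hidden_size = 768             # assumed size of the vision embedding

to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False)
to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=False)

query = torch.randn(batch, heads, seq_len, head_dim)                 # from the base attention
adapter_hidden_states = torch.randn(batch, adapter_tokens, adapter_hidden_size)

# project the vision embeddings into per-head keys and values
vd_key = to_k_adapter(adapter_hidden_states).view(batch, -1, heads, head_dim).transpose(1, 2)
vd_value = to_v_adapter(adapter_hidden_states).view(batch, -1, heads, head_dim).transpose(1, 2)

# second attention pass with the same queries, then blend into the base output
vd_hidden_states = F.scaled_dot_product_attention(query, vd_key, vd_value)
vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch, seq_len, hidden_size)

hidden_states = torch.randn(batch, seq_len, hidden_size)             # stands in for the base attention output
hidden_states = hidden_states + 1.0 * vd_hidden_states               # scale defaults to 1.0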
 class VisionDirectAdapter(torch.nn.Module):
     def __init__(
             self,
@@ -267,6 +453,7 @@ class VisionDirectAdapter(torch.nn.Module):
     ):
         super(VisionDirectAdapter, self).__init__()
         is_pixart = sd.is_pixart
+        is_flux = sd.is_flux
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
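Note the weakref.ref handles here (adapter_ref, sd_ref, vision_model_ref, and adapter_ref inside the processors). Storing the parent directly on an nn.Module would register it as a submodule and drag the whole diffusion model into the adapter's parameters() and state_dict(). A hedged sketch of the pattern with illustrative class names:

import weakref
import torch.nn as nn

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.big = nn.Linear(1024, 1024)
        self.is_active = True

class Child(nn.Module):
    def __init__(self, parent: Parent):
        super().__init__()
        self.parent_ref: weakref.ref = weakref.ref(parent)  # plain attribute, not a registered submodule
        self.small = nn.Linear(8, 8)

    @property
    def is_active(self):
        return self.parent_ref().is_active

parent = Parent()
child = Child(parent)
print(sum(p.numel() for p in child.parameters()))  # counts only the 8x8 layer, not Parent's weights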
@@ -290,11 +477,22 @@ class VisionDirectAdapter(torch.nn.Module):
                 # cross attention
                 attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
 
+        elif is_flux:
+            transformer: FluxTransformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn")
+
+            # single transformer blocks do not have cross attn
+            # for i, module in transformer.single_transformer_blocks.named_children():
+            #     attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())
 
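One detail worth noting in the Flux branch above: named_children() on an nn.ModuleList yields string indices, which is why the f-string key works directly. A small hedged sketch with a stand-in module list:

import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])   # stand-in for transformer_blocks
keys = [f"transformer_blocks.{i}.attn" for i, _ in blocks.named_children()]
print(keys)  # ['transformer_blocks.0.attn', 'transformer_blocks.1.attn', 'transformer_blocks.2.attn']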
         for name in attn_processor_keys:
-            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim']
+            if is_flux:
+                cross_attention_dim = None
+            else:
+                cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
                 hidden_size = sd.unet.config['block_out_channels'][-1]
             elif name.startswith("up_blocks"):
@@ -304,22 +502,27 @@ class VisionDirectAdapter(torch.nn.Module):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = sd.unet.config['block_out_channels'][block_id]
             elif name.startswith("transformer"):
-                hidden_size = sd.unet.config['cross_attention_dim']
+                if is_flux:
+                    hidden_size = 3072
+                else:
+                    hidden_size = sd.unet.config['cross_attention_dim']
             else:
                 # they didn't have this, but would lead to undefined below
                 raise ValueError(f"unknown attn processor name: {name}")
-            if cross_attention_dim is None:
+            if cross_attention_dim is None and not is_flux:
                 attn_procs[name] = AttnProcessor2_0()
             else:
                 layer_name = name.split(".processor")[0]
-                to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
-                to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
-                # if is_pixart:
-                #     to_k_bias = unet_sd[layer_name + ".to_k.bias"]
-                #     to_v_bias = unet_sd[layer_name + ".to_v.bias"]
-                # else:
-                #     to_k_bias = None
-                #     to_v_bias = None
+                if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux:
+                    # is quantized
+
+                    to_k_adapter = torch.randn(hidden_size, hidden_size) * 0.01
+                    to_v_adapter = torch.randn(hidden_size, hidden_size) * 0.01
+                    to_k_adapter = to_k_adapter.to(self.sd_ref().torch_dtype)
+                    to_v_adapter = to_v_adapter.to(self.sd_ref().torch_dtype)
+                else:
+                    to_k_adapter = unet_sd[layer_name + ".to_k.weight"]
+                    to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
 
                 # add zero padding to the adapter
                 if to_k_adapter.shape[1] < self.token_size:
@@ -337,21 +540,6 @@ class VisionDirectAdapter(torch.nn.Module):
                     ],
                     dim=1
                 )
-                # if is_pixart:
-                #     to_k_bias = torch.cat([
-                #         to_k_bias,
-                #         torch.zeros(self.token_size - to_k_adapter.shape[1]).to(
-                #             to_k_adapter.device, dtype=to_k_adapter.dtype)
-                #     ],
-                #     dim=0
-                #     )
-                #     to_v_bias = torch.cat([
-                #         to_v_bias,
-                #         torch.zeros(self.token_size - to_v_adapter.shape[1]).to(
-                #             to_k_adapter.device, dtype=to_k_adapter.dtype)
-                #     ],
-                #     dim=0
-                #     )
                 elif to_k_adapter.shape[1] > self.token_size:
                     to_k_adapter = to_k_adapter[:, :self.token_size]
                     to_v_adapter = to_v_adapter[:, :self.token_size]
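The two hunks above decide where the adapter's to_k/to_v weights start from: copied from the base model's own to_k/to_v when those weights are available, or small random matrices when the Flux checkpoint is quantized (the .weight._data key check), and in either case zero-padded or truncated along the input dimension to match the adapter's token_size. A hedged sketch of that fitting step with arbitrary sizes:

import torch

hidden_size, token_size = 3072, 768
to_k_adapter = torch.randn(hidden_size, 1024)    # copied (out, in) weight with the wrong input width

if to_k_adapter.shape[1] < token_size:
    # pad the missing input columns with zeros
    pad = torch.zeros(hidden_size, token_size - to_k_adapter.shape[1],
                      device=to_k_adapter.device, dtype=to_k_adapter.dtype)
    to_k_adapter = torch.cat([to_k_adapter, pad], dim=1)
elif to_k_adapter.shape[1] > token_size:
    # or drop the surplus columns
    to_k_adapter = to_k_adapter[:, :token_size]

assert to_k_adapter.shape == (hidden_size, token_size)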
@@ -371,16 +559,26 @@ class VisionDirectAdapter(torch.nn.Module):
                 }
                 # if is_pixart:
                 #     weights["to_k_adapter.bias"] = to_k_bias
-                #     weights["to_v_adapter.bias"] = to_v_bias
+                #     weights["to_v_adapter.bias"] = to_v_bias\
 
-                attn_procs[name] = VisionDirectAdapterAttnProcessor(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                    adapter=self,
-                    adapter_hidden_size=self.token_size,
-                    has_bias=False,
-                )
+                if is_flux:
+                    attn_procs[name] = CustomFluxVDAttnProcessor2_0(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        adapter=self,
+                        adapter_hidden_size=self.token_size,
+                        has_bias=False,
+                    )
+                else:
+                    attn_procs[name] = VisionDirectAdapterAttnProcessor(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        adapter=self,
+                        adapter_hidden_size=self.token_size,
+                        has_bias=False,
+                    )
                 attn_procs[name].load_state_dict(weights)
         if self.sd_ref().is_pixart:
             # we have to set them ourselves
@@ -393,14 +591,33 @@ class VisionDirectAdapter(torch.nn.Module):
             ] + [
                 transformer.transformer_blocks[i].attn2.processor for i in range(len(transformer.transformer_blocks))
             ])
+        elif self.sd_ref().is_flux:
+            # we have to set them ourselves
+            transformer: FluxTransformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
+            self.adapter_modules = torch.nn.ModuleList(
+                [
+                    transformer.transformer_blocks[i].attn.processor for i in
+                    range(len(transformer.transformer_blocks))
+                ])
         else:
             sd.unet.set_attn_processor(attn_procs)
             self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
 
+        # add the mlp layer
+        self.mlp = MLP(
+            in_dim=self.token_size,
+            out_dim=self.token_size,
+            hidden_dim=self.token_size,
+            # dropout=0.1,
+            use_residual=True
+        )
+
     # make a getter to see if is active
     @property
     def is_active(self):
         return self.adapter_ref().is_active
 
     def forward(self, input):
-        return input
+        return self.mlp(input)
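With the last hunk, forward() no longer passes the vision embeddings through untouched: they go through the small residual MLP defined earlier in this commit before the attention processors read them as conditional_embeds / unconditional_embeds. A hedged usage sketch; the class body is condensed from the diff above and the shapes are made up:

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, dropout=0.1, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        residual = x
        x = self.dropout(self.fc2(self.act_fn(self.fc1(self.layernorm(x)))))
        return x + residual if self.use_residual else x

mlp = MLP(in_dim=768, out_dim=768, hidden_dim=768, use_residual=True)  # token_size is assumed to be 768 here
image_embeds = torch.randn(2, 1, 768)   # hypothetical CLIP image embeddings
print(mlp(image_embeds).shape)          # torch.Size([2, 1, 768])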