Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

WIP - adding support for flux DoRA and ip adapter training
@@ -5,6 +5,8 @@ import sys
from PIL import Image
from diffusers import Transformer2DModel
from diffusers.models.attention_processor import apply_rope
from torch import nn
from torch.nn import Parameter
from torch.nn.modules.module import T
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -26,6 +28,7 @@ from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resa
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref
from diffusers import FluxTransformer2DModel

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
@@ -234,6 +237,165 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
        # return super(CustomIPAttentionProcessor, self)._apply(fn)


class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False,
                 full_token_scaler=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.train_scaler = train_scaler
        self.num_tokens = num_tokens
        if train_scaler:
            if full_token_scaler:
                self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
            else:
                self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
                # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
            self.ip_scaler.requires_grad_(True)

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        is_active = self.adapter_ref().is_active
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # will be None if disabled
        if not is_active:
            ip_hidden_states = None
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            # just strip it for now?
            image_rotary_emb = image_rotary_emb[:, :, :-self.num_tokens, :, :, :]

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)
        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
            batch_size, -1, attn.heads, head_dim
        ).transpose(1, 2)

        if attn.norm_added_q is not None:
            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
        if attn.norm_added_k is not None:
            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

        # attention
        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            # YiYi to-do: update using apply_rotary_emb
            # from ..embeddings import apply_rotary_emb
            # query = apply_rotary_emb(query, image_rotary_emb)
            # key = apply_rotary_emb(key, image_rotary_emb)
            query, key = apply_rope(query, key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # do ip adapter
        # will be None if disabled
        if ip_hidden_states is not None:
            # apply scaler
            if self.train_scaler:
                weight = self.ip_scaler
                # reshape to (1, self.num_tokens, 1)
                weight = weight.view(1, -1, 1)
                ip_hidden_states = ip_hidden_states * weight

            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )

            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            ip_hidden_states = ip_hidden_states.to(query.dtype)

            scale = self.scale
            hidden_states = hidden_states + scale * ip_hidden_states

        encoder_hidden_states, hidden_states = (
            hidden_states[:, : encoder_hidden_states.shape[1]],
            hidden_states[:, encoder_hidden_states.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states


# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
    """IP-Adapter"""
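Note: a minimal, standalone sketch (not part of the commit; the tensor sizes below are made up for illustration) of the split this processor performs. It assumes the IPAdapter concatenates its image tokens after the text tokens, so the last num_tokens positions of encoder_hidden_states belong to the adapter:

import torch

num_tokens = 4
text_embeds = torch.randn(2, 512, 4096)          # (batch, text_seq_len, dim) -- illustrative shapes only
image_tokens = torch.randn(2, num_tokens, 4096)  # tokens produced by the image projection model

# what the adapter is assumed to hand the transformer
encoder_hidden_states = torch.cat([text_embeds, image_tokens], dim=1)

# what the processor does when active: peel the IP tokens back off the end
end_pos = encoder_hidden_states.shape[1] - num_tokens
text_part = encoder_hidden_states[:, :end_pos, :]
ip_part = encoder_hidden_states[:, end_pos:, :]

assert text_part.shape == text_embeds.shape and ip_part.shape == image_tokens.shape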
@@ -377,6 +539,7 @@ class IPAdapter(torch.nn.Module):
        self.current_scale = 1.0
        self.is_active = True
        is_pixart = sd.is_pixart
        is_flux = sd.is_flux
        if adapter_config.type == 'ip':
            # ip-adapter
            image_proj_model = ImageProjModel(
@@ -393,7 +556,10 @@ class IPAdapter(torch.nn.Module):
            )
        elif adapter_config.type == 'ip+':
            heads = 12 if not sd.is_xl else 20
            dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
            if is_flux:
                dim = 1280
            else:
                dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
                'convnext') else \
                self.image_encoder.config.hidden_sizes[-1]
@@ -406,14 +572,14 @@ class IPAdapter(torch.nn.Module):
                max_seq_len = int(
                    image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

            output_dim = sd.unet.config['cross_attention_dim']

            if is_pixart:
            if is_pixart or is_flux:
                # heads = 20
                heads = 20
                # dim = 4096
                dim = 1280
                output_dim = 4096
            else:
                output_dim = sd.unet.config['cross_attention_dim']

            if self.config.image_encoder_arch.startswith('convnext'):
                in_tokens = 16 * 16
@@ -481,7 +647,14 @@ class IPAdapter(torch.nn.Module):

                # cross attention
                attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
        elif is_flux:
            transformer: FluxTransformer2DModel = sd.unet
            for i, module in transformer.transformer_blocks.named_children():
                attn_processor_keys.append(f"transformer_blocks.{i}.attn")

            # single transformer blocks do not have cross attn
            # for i, module in transformer.single_transformer_blocks.named_children():
            #     attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
        else:
            attn_processor_keys = list(sd.unet.attn_processors.keys())
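Note: a small hedged sketch (toy modules, not the real FluxTransformer2DModel) of why the f-string keys above line up with later lookups: named_children() on an nn.ModuleList yields string indices, so the generated keys come out as "transformer_blocks.0.attn", "transformer_blocks.1.attn", and so on.

import torch.nn as nn

transformer_blocks = nn.ModuleList([nn.Identity() for _ in range(3)])  # stand-in blocks
keys = [f"transformer_blocks.{i}.attn" for i, _ in transformer_blocks.named_children()]
assert keys == ["transformer_blocks.0.attn", "transformer_blocks.1.attn", "transformer_blocks.2.attn"]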
@@ -502,8 +675,11 @@ class IPAdapter(torch.nn.Module):

            if block_name not in blocks:
                blocks.append(block_name)
            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
                sd.unet.config['cross_attention_dim']
            if is_flux:
                cross_attention_dim = None
            else:
                cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
                    sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
                hidden_size = sd.unet.config['block_out_channels'][-1]
            elif name.startswith("up_blocks"):
@@ -513,30 +689,57 @@ class IPAdapter(torch.nn.Module):
                block_id = int(name[len("down_blocks.")])
                hidden_size = sd.unet.config['block_out_channels'][block_id]
            elif name.startswith("transformer"):
                hidden_size = sd.unet.config['cross_attention_dim']
                if is_flux:
                    hidden_size = 3072
                else:
                    hidden_size = sd.unet.config['cross_attention_dim']
            else:
                # they didn't have this, but omitting it would leave hidden_size undefined below
                raise ValueError(f"unknown attn processor name: {name}")
            if cross_attention_dim is None:
            if cross_attention_dim is None and not is_flux:
                attn_procs[name] = AttnProcessor2_0()
            else:
                layer_name = name.split(".processor")[0]

                # if quantized, we need to scale the weights
                if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux:
                    # is quantized

                    k_weight = torch.randn(hidden_size, hidden_size) * 0.01
                    v_weight = torch.randn(hidden_size, hidden_size) * 0.01
                    k_weight = k_weight.to(self.sd_ref().torch_dtype)
                    v_weight = v_weight.to(self.sd_ref().torch_dtype)
                else:
                    k_weight = unet_sd[layer_name + ".to_k.weight"]
                    v_weight = unet_sd[layer_name + ".to_v.weight"]

                weights = {
                    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
                    "to_k_ip.weight": k_weight,
                    "to_v_ip.weight": v_weight
                }

                attn_procs[name] = CustomIPAttentionProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    num_tokens=self.config.num_tokens,
                    adapter=self,
                    train_scaler=self.config.train_scaler or self.config.merge_scaler,
                    # full_token_scaler=self.config.train_scaler  # full token cannot be merged in, only use if training an actual scaler
                    full_token_scaler=False
                )
                if self.sd_ref().is_pixart:
                if is_flux:
                    attn_procs[name] = CustomIPFluxAttnProcessor2_0(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=self.config.num_tokens,
                        adapter=self,
                        train_scaler=self.config.train_scaler or self.config.merge_scaler,
                        full_token_scaler=False
                    )
                else:
                    attn_procs[name] = CustomIPAttentionProcessor(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=self.config.num_tokens,
                        adapter=self,
                        train_scaler=self.config.train_scaler or self.config.merge_scaler,
                        # full_token_scaler=self.config.train_scaler  # full token cannot be merged in, only use if training an actual scaler
                        full_token_scaler=False
                    )
                if self.sd_ref().is_pixart or self.sd_ref().is_flux:
                    # pixart is much more sensitive
                    weights = {
                        "to_k_ip.weight": weights["to_k_ip.weight"] * 0.01,
@@ -558,6 +761,16 @@ class IPAdapter(torch.nn.Module):
                transformer.transformer_blocks[i].attn2.processor for i in
                range(len(transformer.transformer_blocks))
            ])
        elif self.sd_ref().is_flux:
            # we have to set them ourselves
            transformer: FluxTransformer2DModel = sd.unet
            for i, module in transformer.transformer_blocks.named_children():
                module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
            self.adapter_modules = torch.nn.ModuleList(
                [
                    transformer.transformer_blocks[i].attn.processor for i in
                    range(len(transformer.transformer_blocks))
                ])
        else:
            sd.unet.set_attn_processor(attn_procs)
            self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
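Note: a hedged, self-contained sketch (illustrative values only) of constructing one of the flux processors assigned above and loading the to_k_ip / to_v_ip weights it expects. The 3072 hidden size and the 0.01-scaled random init mirror the branch earlier in the diff; the dummy adapter stands in for the real IPAdapter instance.

import torch

class _DummyAdapter:
    # minimal stand-in so adapter_ref().is_active resolves; the real code passes the IPAdapter itself
    is_active = True

dummy_adapter = _DummyAdapter()
hidden_size = 3072                       # hard-coded for FluxTransformer2DModel in this change
k_weight = torch.randn(hidden_size, hidden_size) * 0.01
v_weight = torch.randn(hidden_size, hidden_size) * 0.01

proc = CustomIPFluxAttnProcessor2_0(
    hidden_size=hidden_size,
    cross_attention_dim=None,            # flux blocks get None, so to_k_ip/to_v_ip are hidden_size -> hidden_size
    scale=1.0,
    num_tokens=4,
    adapter=dummy_adapter,
    train_scaler=False,
)
proc.load_state_dict({"to_k_ip.weight": k_weight, "to_v_ip.weight": v_weight})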
@@ -653,7 +866,7 @@ class IPAdapter(torch.nn.Module):

    def set_scale(self, scale):
        self.current_scale = scale
        if not self.sd_ref().is_pixart:
        if not self.sd_ref().is_pixart and not self.sd_ref().is_flux:
            for attn_processor in self.sd_ref().unet.attn_processors.values():
                if isinstance(attn_processor, CustomIPAttentionProcessor):
                    attn_processor.scale = scale
@@ -6,6 +6,8 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import TYPE_CHECKING, Union, List

from optimum.quanto import QBytesTensor, QTensor

from toolkit.network_mixins import ToolkitModuleMixin, ExtractableModuleMixin

if TYPE_CHECKING:
@@ -89,6 +91,7 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):

        # m = Magnitude column-wise across output dimension
        weight = self.get_orig_weight()
        weight = weight.to(self.lora_up.weight.device, dtype=self.lora_up.weight.dtype)
        lora_weight = self.lora_up.weight @ self.lora_down.weight
        weight_norm = self._get_weight_norm(weight, lora_weight)
        self.magnitude = nn.Parameter(weight_norm.detach().clone(), requires_grad=True)
@@ -99,7 +102,11 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
        # del self.org_module

    def get_orig_weight(self):
        return self.org_module[0].weight.data.detach()
        weight = self.org_module[0].weight
        if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
            return weight.dequantize().data.detach()
        else:
            return weight.data.detach()

    def get_orig_bias(self):
        if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
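Note: the dequantize-on-read pattern above, written out as a standalone helper. This is only a sketch; it assumes, as the change itself does, that optimum.quanto's QTensor / QBytesTensor expose .dequantize().

import torch
from optimum.quanto import QBytesTensor, QTensor

def plain_weight(weight: torch.Tensor) -> torch.Tensor:
    # Return a detached, full-precision copy whether or not the weight is quantized.
    if isinstance(weight, (QTensor, QBytesTensor)):
        weight = weight.dequantize()
    return weight.data.detach()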
@@ -126,6 +133,7 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):

        # magnitude = self.lora_magnitude_vector[active_adapter]
        weight = self.get_orig_weight()
        weight = weight.to(scaled_lora_weight.device, dtype=scaled_lora_weight.dtype)
        weight_norm = self._get_weight_norm(weight, scaled_lora_weight)
        # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
        # "[...] we suggest treating ||V +∆V ||_c in
@@ -135,4 +143,4 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
        # during backpropagation"
        weight_norm = weight_norm.detach()
        dora_weight = transpose(weight + scaled_lora_weight, False)
        return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x, dora_weight)
        return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x.to(dora_weight.dtype), dora_weight)
@@ -293,7 +293,7 @@ class ToolkitModuleMixin:
            # todo handle our batch split scalers for slider training. For now take the mean of them
            scale = multiplier.mean()
            scaled_lora_weight = lora_weight * scale
            scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight)
            scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight).to(org_forwarded.dtype)

        try:
            x = org_forwarded + scaled_lora_output
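Note: a hedged numerical sketch (made-up shapes and names, Linear case only, not part of the commit) of why adding the apply_dora correction term to the base and LoRA outputs, as above, reproduces the full DoRA result (m / ||W + ΔW||) · (W + ΔW) x:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 8)
W = torch.randn(16, 8)           # frozen base weight
dW = torch.randn(16, 8) * 0.01   # scaled LoRA weight (lora_up @ lora_down * scale)
m = torch.rand(16) + 0.5         # learned magnitude vector

weight_norm = (W + dW).norm(p=2, dim=1)           # per-output-feature norm, treated as a constant
full_dora = (m / weight_norm).view(1, -1) * F.linear(x, W + dW)

# base forward + plain LoRA output + the correction term computed in apply_dora()
decomposed = F.linear(x, W) + F.linear(x, dW) \
    + (m / weight_norm - 1).view(1, -1) * F.linear(x, W + dW)

assert torch.allclose(full_dora, decomposed, atol=1e-5)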