diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 22a82439..4264071a 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -5,6 +5,8 @@ import sys
 from PIL import Image
 from diffusers import Transformer2DModel
+from diffusers.models.attention_processor import apply_rope
+from torch import nn
 from torch.nn import Parameter
 from torch.nn.modules.module import T
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -26,6 +28,7 @@ from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resa
 from toolkit.config_modules import AdapterConfig
 from toolkit.prompt_utils import PromptEmbeds
 import weakref
+from diffusers import FluxTransformer2DModel
 
 if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
@@ -234,6 +237,165 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
     #     return super(CustomIPAttentionProcessor, self)._apply(fn)
 
 
+class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
+    """IP-Adapter attention processor for Flux joint transformer blocks (SD3-like joint attention plus extra image-prompt key/value projections)."""
+
+    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False,
+                 full_token_scaler=False):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.train_scaler = train_scaler
+        self.num_tokens = num_tokens
+        if train_scaler:
+            if full_token_scaler:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.999)
+            else:
+                self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.999)
+            # self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
+            self.ip_scaler.requires_grad_(True)
+
+    def __call__(
+        self,
+        attn,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        is_active = self.adapter_ref().is_active
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # will be none if disabled
+        if not is_active:
+            ip_hidden_states = None
+        else:
+            # get encoder_hidden_states, ip_hidden_states
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                encoder_hidden_states[:, end_pos:, :],
+            )
+            # just strip it for now?
+            image_rotary_emb = image_rotary_emb[:, :, :-self.num_tokens, :, :, :]
+
+        # `context` projections.
+        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+
+        if attn.norm_added_q is not None:
+            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+        if attn.norm_added_k is not None:
+            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+        # attention
+        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            # YiYi to-do: update using apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # do ip adapter
+        # will be none if disabled
+        if ip_hidden_states is not None:
+            # apply scaler
+            if self.train_scaler:
+                weight = self.ip_scaler
+                # reshape to (1, self.num_tokens, 1)
+                weight = weight.view(1, -1, 1)
+                ip_hidden_states = ip_hidden_states * weight
+
+            # for ip-adapter
+            ip_key = self.to_k_ip(ip_hidden_states)
+            ip_value = self.to_v_ip(ip_hidden_states)
+
+            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+            # the output of sdp = (batch, num_heads, seq_len, head_dim)
+            # TODO: add support for attn.scale when we move to Torch 2.1
+            ip_hidden_states = F.scaled_dot_product_attention(
+                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+
+            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+            scale = self.scale
+            hidden_states = hidden_states + scale * ip_hidden_states
+
+
+        encoder_hidden_states, hidden_states = (
+            hidden_states[:, : encoder_hidden_states.shape[1]],
+            hidden_states[:, encoder_hidden_states.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
 # loosely based on
 # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
 class IPAdapter(torch.nn.Module):
     """IP-Adapter"""
@@ -377,6 +539,7 @@ class IPAdapter(torch.nn.Module):
         self.current_scale = 1.0
         self.is_active = True
         is_pixart = sd.is_pixart
+        is_flux = sd.is_flux
         if adapter_config.type == 'ip':
             # ip-adapter
             image_proj_model = ImageProjModel(
@@ -393,7 +556,10 @@ class IPAdapter(torch.nn.Module):
             )
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
-            dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
+            if is_flux:
+                dim = 1280
+            else:
+                dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
             embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
                 'convnext') else \
                 self.image_encoder.config.hidden_sizes[-1]
@@ -406,14 +572,14 @@ class IPAdapter(torch.nn.Module):
             max_seq_len = int(
                 image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
-            output_dim = sd.unet.config['cross_attention_dim']
-
-            if is_pixart:
+            if is_pixart or is_flux:
                 # heads = 20
                 heads = 20
                 # dim = 4096
                 dim = 1280
                 output_dim = 4096
+            else:
+                output_dim = sd.unet.config['cross_attention_dim']
 
             if self.config.image_encoder_arch.startswith('convnext'):
                 in_tokens = 16 * 16
@@ -481,7 +647,14 @@ class IPAdapter(torch.nn.Module):
                 # cross attention
                 attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+        elif is_flux:
+            transformer: FluxTransformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn")
+            # single transformer blocks do not have cross attn
+            # for i, module in transformer.single_transformer_blocks.named_children():
+            #     attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())
@@ -502,8 +675,11 @@ class IPAdapter(torch.nn.Module):
             if block_name not in blocks:
                 blocks.append(block_name)
-            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
-                sd.unet.config['cross_attention_dim']
+            if is_flux:
+                cross_attention_dim = None
+            else:
+                cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
+                    sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
                 hidden_size = sd.unet.config['block_out_channels'][-1]
             elif name.startswith("up_blocks"):
@@ -513,30 +689,57 @@ class IPAdapter(torch.nn.Module):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = sd.unet.config['block_out_channels'][block_id]
             elif name.startswith("transformer"):
-                hidden_size = sd.unet.config['cross_attention_dim']
+                if is_flux:
+                    hidden_size = 3072
+                else:
+                    hidden_size = sd.unet.config['cross_attention_dim']
             else:
                 # they didnt have this, but would lead to undefined below
                 raise ValueError(f"unknown attn processor name: {name}")
 
-            if cross_attention_dim is None:
+            if cross_attention_dim is None and not is_flux:
                 attn_procs[name] = AttnProcessor2_0()
             else:
                 layer_name = name.split(".processor")[0]
+
+                # if quantized, we cannot copy the original to_k/to_v weights, so start from small random weights
+                if f"{layer_name}.to_k.weight._data" in unet_sd and is_flux:
+                    # is quantized
+
+                    k_weight = torch.randn(hidden_size, hidden_size) * 0.01
+                    v_weight = torch.randn(hidden_size, hidden_size) * 0.01
+                    k_weight = k_weight.to(self.sd_ref().torch_dtype)
+                    v_weight = v_weight.to(self.sd_ref().torch_dtype)
+                else:
+                    k_weight = unet_sd[layer_name + ".to_k.weight"]
+                    v_weight = unet_sd[layer_name + ".to_v.weight"]
+
                 weights = {
-                    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
-                    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
+                    "to_k_ip.weight": k_weight,
+                    "to_v_ip.weight": v_weight
                 }
-                attn_procs[name] = CustomIPAttentionProcessor(
-                    hidden_size=hidden_size,
-                    cross_attention_dim=cross_attention_dim,
-                    scale=1.0,
-                    num_tokens=self.config.num_tokens,
-                    adapter=self,
-                    train_scaler=self.config.train_scaler or self.config.merge_scaler,
-                    # full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler
-                    full_token_scaler=False
-                )
-                if self.sd_ref().is_pixart:
+                if is_flux:
+                    attn_procs[name] = CustomIPFluxAttnProcessor2_0(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=self.config.num_tokens,
+                        adapter=self,
+                        train_scaler=self.config.train_scaler or self.config.merge_scaler,
+                        full_token_scaler=False
+                    )
+                else:
+                    attn_procs[name] = CustomIPAttentionProcessor(
+                        hidden_size=hidden_size,
+                        cross_attention_dim=cross_attention_dim,
+                        scale=1.0,
+                        num_tokens=self.config.num_tokens,
+                        adapter=self,
+                        train_scaler=self.config.train_scaler or self.config.merge_scaler,
+                        # full_token_scaler=self.config.train_scaler # full token cannot be merged in, only use if training an actual scaler
+                        full_token_scaler=False
+                    )
+                if self.sd_ref().is_pixart or self.sd_ref().is_flux:
                     # pixart is much more sensitive
                     weights = {
                         "to_k_ip.weight": weights["to_k_ip.weight"] * 0.01,
@@ -558,6 +761,16 @@ class IPAdapter(torch.nn.Module):
                 transformer.transformer_blocks[i].attn2.processor for i in
                 range(len(transformer.transformer_blocks))
             ])
+        elif self.sd_ref().is_flux:
+            # we have to set them ourselves
+            transformer: FluxTransformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
+            self.adapter_modules = torch.nn.ModuleList(
+                [
+                    transformer.transformer_blocks[i].attn.processor for i in
+                    range(len(transformer.transformer_blocks))
+                ])
         else:
             sd.unet.set_attn_processor(attn_procs)
             self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
@@ -653,7 +866,7 @@ class IPAdapter(torch.nn.Module):
 
     def set_scale(self, scale):
         self.current_scale = scale
-        if not self.sd_ref().is_pixart:
+        if not self.sd_ref().is_pixart and not self.sd_ref().is_flux:
            for attn_processor in self.sd_ref().unet.attn_processors.values():
                if isinstance(attn_processor, CustomIPAttentionProcessor):
                    attn_processor.scale = scale
diff --git a/toolkit/models/DoRA.py b/toolkit/models/DoRA.py
index fb8f4838..653575e9 100644
--- a/toolkit/models/DoRA.py
+++ b/toolkit/models/DoRA.py
@@ -6,6 +6,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import TYPE_CHECKING, Union, List
 
+from optimum.quanto import QBytesTensor, QTensor
+
 from toolkit.network_mixins import ToolkitModuleMixin, ExtractableModuleMixin
 
 if TYPE_CHECKING:
@@ -89,6 +91,7 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
 
         # m = Magnitude column-wise across output dimension
         weight = self.get_orig_weight()
+        weight = weight.to(self.lora_up.weight.device, dtype=self.lora_up.weight.dtype)
         lora_weight = self.lora_up.weight @ self.lora_down.weight
         weight_norm = self._get_weight_norm(weight, lora_weight)
         self.magnitude = nn.Parameter(weight_norm.detach().clone(), requires_grad=True)
@@ -99,7 +102,11 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         # del self.org_module
 
     def get_orig_weight(self):
-        return self.org_module[0].weight.data.detach()
+        weight = self.org_module[0].weight
+        if isinstance(weight, QTensor) or isinstance(weight, QBytesTensor):
+            return weight.dequantize().data.detach()
+        else:
+            return weight.data.detach()
 
     def get_orig_bias(self):
         if hasattr(self.org_module[0], 'bias') and self.org_module[0].bias is not None:
@@ -126,6 +133,7 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
 
         # magnitude = self.lora_magnitude_vector[active_adapter]
         weight = self.get_orig_weight()
+        weight = weight.to(scaled_lora_weight.device, dtype=scaled_lora_weight.dtype)
         weight_norm = self._get_weight_norm(weight, scaled_lora_weight)
         # see section 4.3 of DoRA (https://arxiv.org/abs/2402.09353)
         # "[...] we suggest treating ||V +∆V ||_c in
@@ -135,4 +143,4 @@ class DoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
         # during backpropagation"
         weight_norm = weight_norm.detach()
         dora_weight = transpose(weight + scaled_lora_weight, False)
-        return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x, dora_weight)
+        return (self.magnitude / weight_norm - 1).view(1, -1) * F.linear(x.to(dora_weight.dtype), dora_weight)
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 14ac6161..e2c4fef6 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -293,7 +293,7 @@ class ToolkitModuleMixin:
             # todo handle our batch split scalers for slider training. For now take the mean of them
             scale = multiplier.mean()
             scaled_lora_weight = lora_weight * scale
-            scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight)
+            scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight).to(org_forwarded.dtype)
 
         try:
             x = org_forwarded + scaled_lora_output
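
Reviewer note (not part of the patch): the new CustomIPFluxAttnProcessor2_0 assumes the caller appends the adapter's num_tokens image-prompt embeddings to the end of encoder_hidden_states; it slices them back off before the joint attention and then blends the IP branch in with the global scale. A minimal sketch of that bookkeeping, with made-up batch size, token counts, and dims purely for illustration:

    import torch

    num_tokens = 4
    # text tokens + IP tokens concatenated along the sequence dimension (illustrative shapes)
    encoder_hidden_states = torch.randn(2, 512 + num_tokens, 3072)

    # split the IP tokens back off, as the processor does with end_pos
    end_pos = encoder_hidden_states.shape[1] - num_tokens
    text_states = encoder_hidden_states[:, :end_pos, :]
    ip_states = encoder_hidden_states[:, end_pos:, :]

    # after attention, the IP branch result is merged back with the global scale,
    # mirroring `hidden_states = hidden_states + scale * ip_hidden_states`
    hidden_states = torch.randn(2, 1024 + 512, 3072)      # stand-in for the joint-attention output
    ip_hidden_states = torch.randn_like(hidden_states)    # stand-in for the to_k_ip/to_v_ip attention output
    merged = hidden_states + 1.0 * ip_hidden_states       # scale defaults to 1.0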
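
Reviewer note (not part of the patch): the DoRA changes guard get_orig_weight() against optimum.quanto-quantized base weights by dequantizing before the norm computation. A standalone sketch of the same pattern, assuming optimum.quanto is installed; the helper name is hypothetical:

    import torch
    from optimum.quanto import QBytesTensor, QTensor

    def plain_weight(module: torch.nn.Module) -> torch.Tensor:
        # Return a regular detached tensor whether or not the module's weight was
        # quantized with optimum.quanto, mirroring the patched get_orig_weight().
        weight = module.weight
        if isinstance(weight, (QTensor, QBytesTensor)):
            return weight.dequantize().data.detach()
        return weight.data.detach()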