From 5f312cd46b8d1b2991e074bb4759a6071769256e Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 18 Apr 2025 09:59:42 -0600 Subject: [PATCH] Remove ip adapter submodule --- .gitmodules | 4 - repositories/ipadapter | 1 - ...ert_diffusers_to_comfy_transformer_only.py | 76 ++- toolkit/custom_adapter.py | 4 - toolkit/ip_adapter.py | 14 +- toolkit/models/LoRAFormer.py | 3 - toolkit/models/ilora.py | 8 +- toolkit/models/ilora2.py | 9 +- toolkit/models/te_adapter.py | 14 +- toolkit/models/te_aug_adapter.py | 6 - toolkit/models/vd_adapter.py | 2 - toolkit/reference_adapter.py | 4 - toolkit/util/ip_adapter_utils.py | 634 ++++++++++++++++++ 13 files changed, 709 insertions(+), 70 deletions(-) delete mode 160000 repositories/ipadapter create mode 100644 toolkit/util/ip_adapter_utils.py diff --git a/.gitmodules b/.gitmodules index a98073dc..e790da09 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,7 +10,3 @@ path = repositories/batch_annotator url = https://github.com/ostris/batch-annotator commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf -[submodule "repositories/ipadapter"] - path = repositories/ipadapter - url = https://github.com/tencent-ailab/IP-Adapter.git - commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14 diff --git a/repositories/ipadapter b/repositories/ipadapter deleted file mode 160000 index 5a18b1f3..00000000 --- a/repositories/ipadapter +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 5a18b1f3660acaf8bee8250692d6fb3548a19b14 diff --git a/scripts/convert_diffusers_to_comfy_transformer_only.py b/scripts/convert_diffusers_to_comfy_transformer_only.py index c773c6eb..9973c087 100644 --- a/scripts/convert_diffusers_to_comfy_transformer_only.py +++ b/scripts/convert_diffusers_to_comfy_transformer_only.py @@ -2,10 +2,14 @@ # Convert Diffusers Flux/Flex to diffusion model ComfyUI safetensors file # This will only have the transformer weights, not the TEs and VAE # You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag +# You can also save with scaled 8-bit using the --do_8bit_scaled flag # -# Call like this for 8-bit transformer weights: +# Call like this for 8-bit transformer weights with stochastic rounding: # python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8_bit # +# Call like this for 8-bit transformer weights with scaling: +# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8bit_scaled +# # Call like this for bf16 transformer weights: # python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors # @@ -33,7 +37,9 @@ parser.add_argument("diffusers_path", type=str, parser.add_argument("flux_path", type=str, help="Output path for the Flux safetensors file.") parser.add_argument("--do_8_bit", action="store_true", - help="Use 8-bit weights instead of bf16.") + help="Use 8-bit weights with stochastic rounding instead of bf16.") +parser.add_argument("--do_8bit_scaled", action="store_true", + help="Use scaled 8-bit weights instead of bf16.") args = parser.parse_args() flux_path = Path(args.flux_path) @@ -43,6 +49,12 @@ if os.path.exists(os.path.join(diffusers_path, "transformer")): diffusers_path = Path(os.path.join(diffusers_path, "transformer")) do_8_bit = args.do_8_bit +do_8bit_scaled = args.do_8bit_scaled + +# Don't allow both flags to be active simultaneously +if do_8_bit and do_8bit_scaled: + print("Error: Cannot use both --do_8_bit and --do_8bit_scaled 
at the same time.") + exit() if not os.path.exists(flux_path.parent): os.makedirs(flux_path.parent) @@ -373,16 +385,64 @@ def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn): return rounded.to(dtype) -# set all the keys to bf16 +# List of keys that should not be scaled (usually embedding layers and biases) +blacklist = [] for key in flux.keys(): - if do_8_bit: + if not key.endswith(".weight") or "embed" in key: + blacklist.append(key) + +# Function to scale weights for 8-bit quantization +def scale_weights_to_8bit(tensor, max_value=416.0, dtype=torch.float8_e4m3fn): + # Get the limits of the dtype + min_val = torch.finfo(dtype).min + max_val = torch.finfo(dtype).max + + # Only process 2D tensors that are not in the blacklist + if tensor.dim() == 2: + # Calculate the scaling factor + abs_max = torch.max(torch.abs(tensor)) + scale = abs_max / max_value + + # Scale the tensor and clip to float8 range + scaled_tensor = (tensor / scale).clip(min=min_val, max=max_val).to(dtype) + + return scaled_tensor, scale + else: + # For tensors that shouldn't be scaled, just convert to float8 + return tensor.clip(min=min_val, max=max_val).to(dtype), None + + +# set all the keys to appropriate dtype +if do_8_bit: + print("Converting to 8-bit with stochastic rounding...") + for key in flux.keys(): flux[key] = stochastic_round_to( flux[key], torch.float8_e4m3fn).to('cpu') - else: +elif do_8bit_scaled: + print("Converting to scaled 8-bit...") + scales = {} + for key in tqdm.tqdm(flux.keys()): + if key.endswith(".weight") and key not in blacklist: + flux[key], scale = scale_weights_to_8bit(flux[key]) + if scale is not None: + scale_key = key[:-len(".weight")] + ".scale_weight" + scales[scale_key] = scale + else: + # For non-weight tensors or blacklisted ones, just convert without scaling + min_val = torch.finfo(torch.float8_e4m3fn).min + max_val = torch.finfo(torch.float8_e4m3fn).max + flux[key] = flux[key].clip(min=min_val, max=max_val).to(torch.float8_e4m3fn).to('cpu') + + # Add all the scales to the flux dictionary + flux.update(scales) + + # Add a marker tensor to indicate this is a scaled fp8 model + flux["scaled_fp8"] = torch.tensor([]).to(torch.float8_e4m3fn) +else: + print("Converting to bfloat16...") + for key in flux.keys(): flux[key] = flux[key].clone().to('cpu', torch.bfloat16) - - meta = OrderedDict() meta['format'] = 'pt' # date format like 2024-08-01 YYYY-MM-DD @@ -394,4 +454,4 @@ print(f"Saving to {flux_path}") safetensors.torch.save_file(flux, flux_path, metadata=meta) -print("Done.") +print("Done.") \ No newline at end of file diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py index 820fdc39..cdb65693 100644 --- a/toolkit/custom_adapter.py +++ b/toolkit/custom_adapter.py @@ -31,10 +31,6 @@ from toolkit.util.mask import generate_random_mask sys.path.append(REPOS_ROOT) from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict from collections import OrderedDict -from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \ - AttnProcessor2_0 -from ipadapter.ip_adapter.ip_adapter import ImageProjModel -from ipadapter.ip_adapter.resampler import Resampler from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig from toolkit.prompt_utils import PromptEmbeds import weakref diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 4821e968..33bf2943 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -3,7 +3,6 @@ import random import torch import sys -from PIL 
import Image from diffusers import Transformer2DModel from torch import nn from torch.nn import Parameter @@ -12,18 +11,14 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from toolkit.models.clip_pre_processor import CLIPImagePreProcessor from toolkit.models.zipper_resampler import ZipperResampler -from toolkit.paths import REPOS_ROOT from toolkit.saving import load_ip_adapter_model from toolkit.train_tools import get_torch_dtype from toolkit.util.inverse_cfg import inverse_classifier_guidance -sys.path.append(REPOS_ROOT) from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional from collections import OrderedDict -from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \ - AttnProcessor2_0 -from ipadapter.ip_adapter.ip_adapter import ImageProjModel -from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler +from toolkit.util.ip_adapter_utils import AttnProcessor2_0, IPAttnProcessor2_0, ImageProjModel +from toolkit.resampler import Resampler from toolkit.config_modules import AdapterConfig from toolkit.prompt_utils import PromptEmbeds import weakref @@ -35,9 +30,7 @@ if TYPE_CHECKING: from transformers import ( CLIPImageProcessor, CLIPVisionModelWithProjection, - CLIPVisionModel, AutoImageProcessor, - ConvNextModel, ConvNextV2ForImageClassification, ConvNextForImageClassification, ConvNextImageProcessor @@ -48,9 +41,6 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassificatio from transformers import ViTFeatureExtractor, ViTForImageClassification -# gradient checkpointing -from torch.utils.checkpoint import checkpoint - import torch.nn.functional as F diff --git a/toolkit/models/LoRAFormer.py b/toolkit/models/LoRAFormer.py index 78bb460d..2097a560 100644 --- a/toolkit/models/LoRAFormer.py +++ b/toolkit/models/LoRAFormer.py @@ -7,9 +7,6 @@ from typing import TYPE_CHECKING, List, Dict, Any from toolkit.models.clip_fusion import ZipperBlock from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler import sys -from toolkit.paths import REPOS_ROOT -sys.path.append(REPOS_ROOT) -from ipadapter.ip_adapter.resampler import Resampler from collections import OrderedDict if TYPE_CHECKING: diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py index 33613ed3..886d263c 100644 --- a/toolkit/models/ilora.py +++ b/toolkit/models/ilora.py @@ -4,13 +4,7 @@ import weakref import torch import torch.nn as nn from typing import TYPE_CHECKING, List, Dict, Any -from toolkit.models.clip_fusion import ZipperBlock -from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler -import sys -from toolkit.paths import REPOS_ROOT -sys.path.append(REPOS_ROOT) -from ipadapter.ip_adapter.resampler import Resampler -from collections import OrderedDict +from toolkit.resampler import Resampler if TYPE_CHECKING: from toolkit.lora_special import LoRAModule diff --git a/toolkit/models/ilora2.py b/toolkit/models/ilora2.py index c46bd0a6..2aba5eae 100644 --- a/toolkit/models/ilora2.py +++ b/toolkit/models/ilora2.py @@ -5,14 +5,7 @@ from toolkit.config_modules import AdapterConfig import torch import torch.nn as nn from typing import TYPE_CHECKING, List, Dict, Any -from toolkit.models.clip_fusion import ZipperBlock -from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler -import sys -from toolkit.paths import REPOS_ROOT - -sys.path.append(REPOS_ROOT) -from ipadapter.ip_adapter.resampler import Resampler -from collections 
import OrderedDict +from toolkit.resampler import Resampler if TYPE_CHECKING: from toolkit.lora_special import LoRAModule diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py index cc7679aa..37396dd7 100644 --- a/toolkit/models/te_adapter.py +++ b/toolkit/models/te_adapter.py @@ -5,23 +5,15 @@ import torch.nn as nn import torch.nn.functional as F import weakref from typing import Union, TYPE_CHECKING - - -from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection -from diffusers.models.embeddings import PixArtAlphaTextProjection - +from transformers import T5EncoderModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection from toolkit import train_tools -from toolkit.paths import REPOS_ROOT from toolkit.prompt_utils import PromptEmbeds from diffusers import Transformer2DModel - -sys.path.append(REPOS_ROOT) - -from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0 +from toolkit.util.ip_adapter_utils import AttnProcessor2_0 if TYPE_CHECKING: - from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline + from toolkit.stable_diffusion_model import StableDiffusion from toolkit.custom_adapter import CustomAdapter diff --git a/toolkit/models/te_aug_adapter.py b/toolkit/models/te_aug_adapter.py index 02cbbec1..a8e74bea 100644 --- a/toolkit/models/te_aug_adapter.py +++ b/toolkit/models/te_aug_adapter.py @@ -10,12 +10,6 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule -from toolkit.paths import REPOS_ROOT -from toolkit.resampler import Resampler - -sys.path.append(REPOS_ROOT) - -from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0 if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py index 52c38cec..4af8fe88 100644 --- a/toolkit/models/vd_adapter.py +++ b/toolkit/models/vd_adapter.py @@ -13,8 +13,6 @@ from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionIma from transformers import SiglipImageProcessor, SiglipVisionModel import traceback from toolkit.config_modules import AdapterConfig -from toolkit.paths import REPOS_ROOT -sys.path.append(REPOS_ROOT) if TYPE_CHECKING: diff --git a/toolkit/reference_adapter.py b/toolkit/reference_adapter.py index d00dfb72..ec7b26f3 100644 --- a/toolkit/reference_adapter.py +++ b/toolkit/reference_adapter.py @@ -15,10 +15,6 @@ from toolkit.train_tools import get_torch_dtype sys.path.append(REPOS_ROOT) from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict from collections import OrderedDict -from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \ - AttnProcessor2_0 -from ipadapter.ip_adapter.ip_adapter import ImageProjModel -from ipadapter.ip_adapter.resampler import Resampler from toolkit.config_modules import AdapterConfig from toolkit.prompt_utils import PromptEmbeds import weakref diff --git a/toolkit/util/ip_adapter_utils.py b/toolkit/util/ip_adapter_utils.py new file mode 100644 index 00000000..8e80643a --- /dev/null +++ b/toolkit/util/ip_adapter_utils.py @@ -0,0 +1,634 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import 
torch.nn.functional as F + + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear( + cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear( + cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + self.attn_map = ip_attention_probs + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. 
+ """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear( + cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear( + cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + ip_value = ip_value.view( + batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + with torch.no_grad(): + self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1) + # 
print(self.attn_map.shape) + + ip_hidden_states = ip_hidden_states.transpose( + 1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +# for controlnet +class CNAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __init__(self, num_tokens=4): + self.num_tokens = num_tokens + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + # only use text + encoder_hidden_states = encoder_hidden_states[:, :end_pos] + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CNAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
+ """ + + def __init__(self, num_tokens=4): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.num_tokens = num_tokens + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + # only use text + encoder_hidden_states = encoder_hidden_states[:, :end_pos] + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, + head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): + super().__init__() + + self.generator = None + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear( + clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + 
clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens
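
Note on how the vendored pieces fit together (not part of the diff): `ImageProjModel` turns a pooled CLIP image embedding into `clip_extra_context_tokens` extra context tokens, and `IPAttnProcessor2_0(num_tokens=...)` expects those tokens to be concatenated onto the end of the text embeddings, since it slices them back off with `end_pos = seq_len - num_tokens`. A minimal sketch of that wiring with dummy tensors, assuming the module path introduced by this patch; the dimensions are illustrative only:

```python
# Minimal sketch (not part of this patch): how ImageProjModel output is meant to be
# appended to the text context so IPAttnProcessor2_0 can split it back off.
# Assumes the toolkit.util.ip_adapter_utils path added here; shapes are illustrative.
import torch
from toolkit.util.ip_adapter_utils import ImageProjModel

proj = ImageProjModel(
    cross_attention_dim=768,      # must match the UNet cross-attention dim
    clip_embeddings_dim=1024,     # pooled CLIP image embedding size
    clip_extra_context_tokens=4,  # same value passed as num_tokens to IPAttnProcessor2_0
)

image_embeds = torch.randn(2, 1024)            # (batch, clip_embeddings_dim)
ip_tokens = proj(image_embeds)                 # (batch, 4, 768)

text_embeds = torch.randn(2, 77, 768)          # (batch, seq_len, cross_attention_dim)
encoder_hidden_states = torch.cat([text_embeds, ip_tokens], dim=1)
# Inside IPAttnProcessor2_0.__call__ the image tokens are recovered with
# end_pos = encoder_hidden_states.shape[1] - num_tokens.
```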
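
For the `--do_8bit_scaled` path in the conversion script, each scaled 2D weight is stored as `float8_e4m3fn` together with a `<name>.scale_weight` tensor, and an empty `scaled_fp8` tensor marks the checkpoint as scaled fp8. A minimal sketch of how a loader could undo that scaling, relying only on the key convention shown in the script (this is an illustration, not code from this patch, and the filename follows the usage examples in the script header):

```python
# Minimal sketch (not part of this patch): recovering bf16 weights from a checkpoint
# written with --do_8bit_scaled. Scaled weights were stored as (w / scale).to(float8),
# so multiplying by the saved ".scale_weight" tensor undoes the scaling.
import torch
import safetensors.torch

state = safetensors.torch.load_file("my_finetune.safetensors")
assert "scaled_fp8" in state, "not a scaled fp8 checkpoint"

dequantized = {}
for key, tensor in state.items():
    if key == "scaled_fp8" or key.endswith(".scale_weight"):
        continue  # marker and scale tensors are metadata, not weights
    scale_key = key[:-len(".weight")] + ".scale_weight" if key.endswith(".weight") else None
    if scale_key is not None and scale_key in state:
        dequantized[key] = tensor.to(torch.bfloat16) * state[scale_key].to(torch.bfloat16)
    else:
        # blacklisted / non-weight tensors were converted without a scale factor
        dequantized[key] = tensor.to(torch.bfloat16)
```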