Remove ip adapter submodule

Commit 5f312cd46b (parent c90615f8bb)
Author: Jaret Burkett
Date: 2025-04-18 09:59:42 -06:00

13 changed files with 709 additions and 70 deletions

.gitmodules (4 lines changed)

@@ -10,7 +10,3 @@
 	path = repositories/batch_annotator
 	url = https://github.com/ostris/batch-annotator
 	commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
-[submodule "repositories/ipadapter"]
-	path = repositories/ipadapter
-	url = https://github.com/tencent-ailab/IP-Adapter.git
-	commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14

convert_diffusers_to_comfy_transformer_only.py

@@ -2,10 +2,14 @@
 # Convert Diffusers Flux/Flex to diffusion model ComfyUI safetensors file
 # This will only have the transformer weights, not the TEs and VAE
 # You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag
+# You can also save with scaled 8-bit using the --do_8bit_scaled flag
 #
-# Call like this for 8-bit transformer weights:
+# Call like this for 8-bit transformer weights with stochastic rounding:
 # python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8_bit
 #
+# Call like this for 8-bit transformer weights with scaling:
+# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8bit_scaled
+#
 # Call like this for bf16 transformer weights:
 # python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors
 #
@@ -33,7 +37,9 @@ parser.add_argument("diffusers_path", type=str,
 parser.add_argument("flux_path", type=str,
                     help="Output path for the Flux safetensors file.")
 parser.add_argument("--do_8_bit", action="store_true",
-                    help="Use 8-bit weights instead of bf16.")
+                    help="Use 8-bit weights with stochastic rounding instead of bf16.")
+parser.add_argument("--do_8bit_scaled", action="store_true",
+                    help="Use scaled 8-bit weights instead of bf16.")
 args = parser.parse_args()
 
 flux_path = Path(args.flux_path)
@@ -43,6 +49,12 @@ if os.path.exists(os.path.join(diffusers_path, "transformer")):
     diffusers_path = Path(os.path.join(diffusers_path, "transformer"))
 
 do_8_bit = args.do_8_bit
+do_8bit_scaled = args.do_8bit_scaled
+
+# Don't allow both flags to be active simultaneously
+if do_8_bit and do_8bit_scaled:
+    print("Error: Cannot use both --do_8_bit and --do_8bit_scaled at the same time.")
+    exit()
 
 if not os.path.exists(flux_path.parent):
     os.makedirs(flux_path.parent)
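
Editorial aside: the manual check above works; argparse can also enforce the same exclusivity natively. A minimal sketch, not part of the commit:

import argparse

# argparse rejects the flag combination itself, so no manual check is needed
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--do_8_bit", action="store_true",
                   help="Use 8-bit weights with stochastic rounding instead of bf16.")
group.add_argument("--do_8bit_scaled", action="store_true",
                   help="Use scaled 8-bit weights instead of bf16.")
args = parser.parse_args()  # passing both flags exits with a usage error
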
@@ -373,16 +385,64 @@ def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn):
     return rounded.to(dtype)
 
-# set all the keys to bf16
+# List of keys that should not be scaled (usually embedding layers and biases)
+blacklist = []
 for key in flux.keys():
-    if do_8_bit:
+    if not key.endswith(".weight") or "embed" in key:
+        blacklist.append(key)
+
+
+# Function to scale weights for 8-bit quantization
+def scale_weights_to_8bit(tensor, max_value=416.0, dtype=torch.float8_e4m3fn):
+    # Get the limits of the dtype
+    min_val = torch.finfo(dtype).min
+    max_val = torch.finfo(dtype).max
+
+    # Only process 2D tensors that are not in the blacklist
+    if tensor.dim() == 2:
+        # Calculate the scaling factor
+        abs_max = torch.max(torch.abs(tensor))
+        scale = abs_max / max_value
+
+        # Scale the tensor and clip to float8 range
+        scaled_tensor = (tensor / scale).clip(min=min_val, max=max_val).to(dtype)
+        return scaled_tensor, scale
+    else:
+        # For tensors that shouldn't be scaled, just convert to float8
+        return tensor.clip(min=min_val, max=max_val).to(dtype), None
+
+
+# set all the keys to appropriate dtype
+if do_8_bit:
+    print("Converting to 8-bit with stochastic rounding...")
+    for key in flux.keys():
         flux[key] = stochastic_round_to(
             flux[key], torch.float8_e4m3fn).to('cpu')
-    else:
+elif do_8bit_scaled:
+    print("Converting to scaled 8-bit...")
+    scales = {}
+    for key in tqdm.tqdm(flux.keys()):
+        if key.endswith(".weight") and key not in blacklist:
+            flux[key], scale = scale_weights_to_8bit(flux[key])
+            if scale is not None:
+                scale_key = key[:-len(".weight")] + ".scale_weight"
+                scales[scale_key] = scale
+        else:
+            # For non-weight tensors or blacklisted ones, just convert without scaling
+            min_val = torch.finfo(torch.float8_e4m3fn).min
+            max_val = torch.finfo(torch.float8_e4m3fn).max
+            flux[key] = flux[key].clip(min=min_val, max=max_val).to(torch.float8_e4m3fn).to('cpu')
+
+    # Add all the scales to the flux dictionary
+    flux.update(scales)
+
+    # Add a marker tensor to indicate this is a scaled fp8 model
+    flux["scaled_fp8"] = torch.tensor([]).to(torch.float8_e4m3fn)
+else:
+    print("Converting to bfloat16...")
+    for key in flux.keys():
         flux[key] = flux[key].clone().to('cpu', torch.bfloat16)
 
 meta = OrderedDict()
 meta['format'] = 'pt'
 # date format like 2024-08-01 YYYY-MM-DD
@@ -394,4 +454,4 @@ print(f"Saving to {flux_path}")
 safetensors.torch.save_file(flux, flux_path, metadata=meta)
 print("Done.")

(next changed file)

@@ -31,10 +31,6 @@ from toolkit.util.mask import generate_random_mask
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
 from collections import OrderedDict
-from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
-    AttnProcessor2_0
-from ipadapter.ip_adapter.ip_adapter import ImageProjModel
-from ipadapter.ip_adapter.resampler import Resampler
 from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
 from toolkit.prompt_utils import PromptEmbeds
 import weakref

(next changed file)

@@ -3,7 +3,6 @@ import random
 
 import torch
 import sys
-from PIL import Image
 from diffusers import Transformer2DModel
 from torch import nn
 from torch.nn import Parameter
@@ -12,18 +11,14 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.models.zipper_resampler import ZipperResampler
-from toolkit.paths import REPOS_ROOT
 from toolkit.saving import load_ip_adapter_model
 from toolkit.train_tools import get_torch_dtype
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
 
-sys.path.append(REPOS_ROOT)
 
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
 from collections import OrderedDict
-from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
-    AttnProcessor2_0
-from ipadapter.ip_adapter.ip_adapter import ImageProjModel
-from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler
+from toolkit.util.ip_adapter_utils import AttnProcessor2_0, IPAttnProcessor2_0, ImageProjModel
+from toolkit.resampler import Resampler
 from toolkit.config_modules import AdapterConfig
 from toolkit.prompt_utils import PromptEmbeds
 import weakref
@@ -35,9 +30,7 @@ if TYPE_CHECKING:
 from transformers import (
     CLIPImageProcessor,
     CLIPVisionModelWithProjection,
-    CLIPVisionModel,
     AutoImageProcessor,
-    ConvNextModel,
     ConvNextV2ForImageClassification,
     ConvNextForImageClassification,
     ConvNextImageProcessor
@@ -48,9 +41,6 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
 from transformers import ViTFeatureExtractor, ViTForImageClassification
 
-# gradient checkpointing
-from torch.utils.checkpoint import checkpoint
-
 import torch.nn.functional as F

(next changed file)

@@ -7,9 +7,6 @@ from typing import TYPE_CHECKING, List, Dict, Any
 from toolkit.models.clip_fusion import ZipperBlock
 from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
 import sys
-from toolkit.paths import REPOS_ROOT
-sys.path.append(REPOS_ROOT)
-from ipadapter.ip_adapter.resampler import Resampler
 from collections import OrderedDict
 
 if TYPE_CHECKING:

(next changed file)

@@ -4,13 +4,7 @@ import weakref
 import torch
 import torch.nn as nn
 from typing import TYPE_CHECKING, List, Dict, Any
-from toolkit.models.clip_fusion import ZipperBlock
-from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
-import sys
-from toolkit.paths import REPOS_ROOT
-sys.path.append(REPOS_ROOT)
-from ipadapter.ip_adapter.resampler import Resampler
-from collections import OrderedDict
+from toolkit.resampler import Resampler
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRAModule

(next changed file)

@@ -5,14 +5,7 @@ from toolkit.config_modules import AdapterConfig
 import torch
 import torch.nn as nn
 from typing import TYPE_CHECKING, List, Dict, Any
-from toolkit.models.clip_fusion import ZipperBlock
-from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
-import sys
-from toolkit.paths import REPOS_ROOT
-sys.path.append(REPOS_ROOT)
-from ipadapter.ip_adapter.resampler import Resampler
-from collections import OrderedDict
-
+from toolkit.resampler import Resampler
 
 if TYPE_CHECKING:
     from toolkit.lora_special import LoRAModule

(next changed file)

@@ -5,23 +5,15 @@ import torch.nn as nn
 import torch.nn.functional as F
 import weakref
 from typing import Union, TYPE_CHECKING
-from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
-from diffusers.models.embeddings import PixArtAlphaTextProjection
+from transformers import T5EncoderModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
 from toolkit import train_tools
-from toolkit.paths import REPOS_ROOT
 from toolkit.prompt_utils import PromptEmbeds
 from diffusers import Transformer2DModel
+from toolkit.util.ip_adapter_utils import AttnProcessor2_0
 
-sys.path.append(REPOS_ROOT)
-from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
-
 if TYPE_CHECKING:
-    from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
+    from toolkit.stable_diffusion_model import StableDiffusion
     from toolkit.custom_adapter import CustomAdapter

(next changed file)

@@ -10,12 +10,6 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
 from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention
 from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
-from toolkit.paths import REPOS_ROOT
-from toolkit.resampler import Resampler
-
-sys.path.append(REPOS_ROOT)
-from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
-
 
 if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion

(next changed file)

@@ -13,8 +13,6 @@ from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionIma
 from transformers import SiglipImageProcessor, SiglipVisionModel
 import traceback
 from toolkit.config_modules import AdapterConfig
-from toolkit.paths import REPOS_ROOT
-sys.path.append(REPOS_ROOT)
 
 if TYPE_CHECKING:

(next changed file)

@@ -15,10 +15,6 @@ from toolkit.train_tools import get_torch_dtype
 sys.path.append(REPOS_ROOT)
 from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
 from collections import OrderedDict
-from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
-    AttnProcessor2_0
-from ipadapter.ip_adapter.ip_adapter import ImageProjModel
-from ipadapter.ip_adapter.resampler import Resampler
 from toolkit.config_modules import AdapterConfig
 from toolkit.prompt_utils import PromptEmbeds
 import weakref

toolkit/util/ip_adapter_utils.py (new file)

@@ -0,0 +1,634 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
        self.attn_map = ip_attention_probs
        ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = ip_key.view(batch_size, -1, attn.heads,
                             head_dim).transpose(1, 2)
        ip_value = ip_value.view(
            batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        ip_hidden_states = F.scaled_dot_product_attention(
            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
        )
        with torch.no_grad():
            self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
            # print(self.attn_map.shape)

        ip_hidden_states = ip_hidden_states.transpose(
            1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


# for controlnet
class CNAttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            # only use text
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CNAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            # only use text
            encoder_hidden_states = encoder_hidden_states[:, :end_pos]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(
            clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens
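
Editorial aside: these vendored classes follow the upstream IP-Adapter usage pattern, where IPAttnProcessor2_0 replaces the cross-attention processors of a diffusers UNet and ImageProjModel turns a pooled CLIP image embedding into the extra num_tokens context tokens the processor splits off again. A sketch under those assumptions (the model id, the 1024-dim CLIP embedding, and the dummy tensors are illustrative, not something this commit configures):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet")

attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1 is self-attention (no image tokens); attn2 is cross-attention
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor2_0()
    else:
        attn_procs[name] = IPAttnProcessor2_0(
            hidden_size=hidden_size,
            cross_attention_dim=cross_attention_dim,
            scale=1.0,
            num_tokens=4,
        )
unet.set_attn_processor(attn_procs)

# ImageProjModel maps a pooled CLIP image embedding (1024-dim here, an assumption)
# to num_tokens extra context tokens, concatenated after the text embeddings so the
# IP processors can split them off again via num_tokens.
proj = ImageProjModel(cross_attention_dim=unet.config.cross_attention_dim,
                      clip_embeddings_dim=1024, clip_extra_context_tokens=4)
image_embeds = torch.randn(1, 1024)  # dummy pooled CLIP image embedding
ip_tokens = proj(image_embeds)       # (1, 4, cross_attention_dim)
text_embeds = torch.randn(1, 77, unet.config.cross_attention_dim)  # dummy text embeds
encoder_hidden_states = torch.cat([text_embeds, ip_tokens], dim=1)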