Remove ip adapter submodule

This commit is contained in:
Jaret Burkett
2025-04-18 09:59:42 -06:00
parent c90615f8bb
commit 5f312cd46b
13 changed files with 709 additions and 70 deletions

4
.gitmodules vendored
View File

@@ -10,7 +10,3 @@
path = repositories/batch_annotator
url = https://github.com/ostris/batch-annotator
commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
[submodule "repositories/ipadapter"]
path = repositories/ipadapter
url = https://github.com/tencent-ailab/IP-Adapter.git
commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14

View File

@@ -2,10 +2,14 @@
# Convert Diffusers Flux/Flex to diffusion model ComfyUI safetensors file
# This will only have the transformer weights, not the TEs and VAE
# You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag
# You can also save with scaled 8-bit using the --do_8bit_scaled flag
#
# Call like this for 8-bit transformer weights:
# Call like this for 8-bit transformer weights with stochastic rounding:
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8_bit
#
# Call like this for 8-bit transformer weights with scaling:
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8bit_scaled
#
# Call like this for bf16 transformer weights:
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors
#
@@ -33,7 +37,9 @@ parser.add_argument("diffusers_path", type=str,
parser.add_argument("flux_path", type=str,
help="Output path for the Flux safetensors file.")
parser.add_argument("--do_8_bit", action="store_true",
help="Use 8-bit weights instead of bf16.")
help="Use 8-bit weights with stochastic rounding instead of bf16.")
parser.add_argument("--do_8bit_scaled", action="store_true",
help="Use scaled 8-bit weights instead of bf16.")
args = parser.parse_args()
flux_path = Path(args.flux_path)
@@ -43,6 +49,12 @@ if os.path.exists(os.path.join(diffusers_path, "transformer")):
diffusers_path = Path(os.path.join(diffusers_path, "transformer"))
do_8_bit = args.do_8_bit
do_8bit_scaled = args.do_8bit_scaled
# Don't allow both flags to be active simultaneously
if do_8_bit and do_8bit_scaled:
print("Error: Cannot use both --do_8_bit and --do_8bit_scaled at the same time.")
exit()
if not os.path.exists(flux_path.parent):
os.makedirs(flux_path.parent)
@@ -373,16 +385,64 @@ def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn):
return rounded.to(dtype)
# set all the keys to bf16
# List of keys that should not be scaled (usually embedding layers and biases)
blacklist = []
for key in flux.keys():
if do_8_bit:
if not key.endswith(".weight") or "embed" in key:
blacklist.append(key)
# Function to scale weights for 8-bit quantization
def scale_weights_to_8bit(tensor, max_value=416.0, dtype=torch.float8_e4m3fn):
# Get the limits of the dtype
min_val = torch.finfo(dtype).min
max_val = torch.finfo(dtype).max
# Only process 2D tensors that are not in the blacklist
if tensor.dim() == 2:
# Calculate the scaling factor
abs_max = torch.max(torch.abs(tensor))
scale = abs_max / max_value
# Scale the tensor and clip to float8 range
scaled_tensor = (tensor / scale).clip(min=min_val, max=max_val).to(dtype)
return scaled_tensor, scale
else:
# For tensors that shouldn't be scaled, just convert to float8
return tensor.clip(min=min_val, max=max_val).to(dtype), None
# set all the keys to appropriate dtype
if do_8_bit:
print("Converting to 8-bit with stochastic rounding...")
for key in flux.keys():
flux[key] = stochastic_round_to(
flux[key], torch.float8_e4m3fn).to('cpu')
else:
elif do_8bit_scaled:
print("Converting to scaled 8-bit...")
scales = {}
for key in tqdm.tqdm(flux.keys()):
if key.endswith(".weight") and key not in blacklist:
flux[key], scale = scale_weights_to_8bit(flux[key])
if scale is not None:
scale_key = key[:-len(".weight")] + ".scale_weight"
scales[scale_key] = scale
else:
# For non-weight tensors or blacklisted ones, just convert without scaling
min_val = torch.finfo(torch.float8_e4m3fn).min
max_val = torch.finfo(torch.float8_e4m3fn).max
flux[key] = flux[key].clip(min=min_val, max=max_val).to(torch.float8_e4m3fn).to('cpu')
# Add all the scales to the flux dictionary
flux.update(scales)
# Add a marker tensor to indicate this is a scaled fp8 model
flux["scaled_fp8"] = torch.tensor([]).to(torch.float8_e4m3fn)
else:
print("Converting to bfloat16...")
for key in flux.keys():
flux[key] = flux[key].clone().to('cpu', torch.bfloat16)
meta = OrderedDict()
meta['format'] = 'pt'
# date format like 2024-08-01 YYYY-MM-DD
@@ -394,4 +454,4 @@ print(f"Saving to {flux_path}")
safetensors.torch.save_file(flux, flux_path, metadata=meta)
print("Done.")
print("Done.")

View File

@@ -31,10 +31,6 @@ from toolkit.util.mask import generate_random_mask
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref

View File

@@ -3,7 +3,6 @@ import random
import torch
import sys
from PIL import Image
from diffusers import Transformer2DModel
from torch import nn
from torch.nn import Parameter
@@ -12,18 +11,14 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.zipper_resampler import ZipperResampler
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
from toolkit.util.inverse_cfg import inverse_classifier_guidance
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler
from toolkit.util.ip_adapter_utils import AttnProcessor2_0, IPAttnProcessor2_0, ImageProjModel
from toolkit.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref
@@ -35,9 +30,7 @@ if TYPE_CHECKING:
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
CLIPVisionModel,
AutoImageProcessor,
ConvNextModel,
ConvNextV2ForImageClassification,
ConvNextForImageClassification,
ConvNextImageProcessor
@@ -48,9 +41,6 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassificatio
from transformers import ViTFeatureExtractor, ViTForImageClassification
# gradient checkpointing
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F

View File

@@ -7,9 +7,6 @@ from typing import TYPE_CHECKING, List, Dict, Any
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
import sys
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.resampler import Resampler
from collections import OrderedDict
if TYPE_CHECKING:

View File

@@ -4,13 +4,7 @@ import weakref
import torch
import torch.nn as nn
from typing import TYPE_CHECKING, List, Dict, Any
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
import sys
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.resampler import Resampler
from collections import OrderedDict
from toolkit.resampler import Resampler
if TYPE_CHECKING:
from toolkit.lora_special import LoRAModule

View File

@@ -5,14 +5,7 @@ from toolkit.config_modules import AdapterConfig
import torch
import torch.nn as nn
from typing import TYPE_CHECKING, List, Dict, Any
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
import sys
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.resampler import Resampler
from collections import OrderedDict
from toolkit.resampler import Resampler
if TYPE_CHECKING:
from toolkit.lora_special import LoRAModule

View File

@@ -5,23 +5,15 @@ import torch.nn as nn
import torch.nn.functional as F
import weakref
from typing import Union, TYPE_CHECKING
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
from diffusers.models.embeddings import PixArtAlphaTextProjection
from transformers import T5EncoderModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
from toolkit import train_tools
from toolkit.paths import REPOS_ROOT
from toolkit.prompt_utils import PromptEmbeds
from diffusers import Transformer2DModel
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
from toolkit.util.ip_adapter_utils import AttnProcessor2_0
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.custom_adapter import CustomAdapter

View File

@@ -10,12 +10,6 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention
from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
from toolkit.paths import REPOS_ROOT
from toolkit.resampler import Resampler
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion

View File

@@ -13,8 +13,6 @@ from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionIma
from transformers import SiglipImageProcessor, SiglipVisionModel
import traceback
from toolkit.config_modules import AdapterConfig
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
if TYPE_CHECKING:

View File

@@ -15,10 +15,6 @@ from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref

View File

@@ -0,0 +1,634 @@
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class AttnProcessor(nn.Module):
r"""
Default processor for performing attention-related computations.
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(
hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(
-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAttnProcessor(nn.Module):
r"""
Attention processor for IP-Adapater.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(
cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(
cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(
hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
self.attn_map = ip_attention_probs
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(
-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor2_0(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(
hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(
-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapater for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
The context length of the image features.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.num_tokens = num_tokens
self.to_k_ip = nn.Linear(
cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(
cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(
hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
ip_value = ip_value.view(
batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
with torch.no_grad():
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
# print(self.attn_map.shape)
ip_hidden_states = ip_hidden_states.transpose(
1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(
-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
# for controlnet
class CNAttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __init__(self, num_tokens=4):
self.num_tokens = num_tokens
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(
hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
# only use text
encoder_hidden_states = encoder_hidden_states[:, :end_pos]
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(
-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CNAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self, num_tokens=4):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.num_tokens = num_tokens
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(
hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
else:
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
# only use text
encoder_hidden_states = encoder_hidden_states[:, :end_pos]
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads,
head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(
-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.generator = None
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(
clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens