mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Remove ip adapter submodule
This commit is contained in:
4
.gitmodules
vendored
4
.gitmodules
vendored
@@ -10,7 +10,3 @@
|
||||
path = repositories/batch_annotator
|
||||
url = https://github.com/ostris/batch-annotator
|
||||
commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
|
||||
[submodule "repositories/ipadapter"]
|
||||
path = repositories/ipadapter
|
||||
url = https://github.com/tencent-ailab/IP-Adapter.git
|
||||
commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14
|
||||
|
||||
Submodule repositories/ipadapter deleted from 5a18b1f366
@@ -2,10 +2,14 @@
|
||||
# Convert Diffusers Flux/Flex to diffusion model ComfyUI safetensors file
|
||||
# This will only have the transformer weights, not the TEs and VAE
|
||||
# You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag
|
||||
# You can also save with scaled 8-bit using the --do_8bit_scaled flag
|
||||
#
|
||||
# Call like this for 8-bit transformer weights:
|
||||
# Call like this for 8-bit transformer weights with stochastic rounding:
|
||||
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8_bit
|
||||
#
|
||||
# Call like this for 8-bit transformer weights with scaling:
|
||||
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8bit_scaled
|
||||
#
|
||||
# Call like this for bf16 transformer weights:
|
||||
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors
|
||||
#
|
||||
@@ -33,7 +37,9 @@ parser.add_argument("diffusers_path", type=str,
|
||||
parser.add_argument("flux_path", type=str,
|
||||
help="Output path for the Flux safetensors file.")
|
||||
parser.add_argument("--do_8_bit", action="store_true",
|
||||
help="Use 8-bit weights instead of bf16.")
|
||||
help="Use 8-bit weights with stochastic rounding instead of bf16.")
|
||||
parser.add_argument("--do_8bit_scaled", action="store_true",
|
||||
help="Use scaled 8-bit weights instead of bf16.")
|
||||
args = parser.parse_args()
|
||||
|
||||
flux_path = Path(args.flux_path)
|
||||
@@ -43,6 +49,12 @@ if os.path.exists(os.path.join(diffusers_path, "transformer")):
|
||||
diffusers_path = Path(os.path.join(diffusers_path, "transformer"))
|
||||
|
||||
do_8_bit = args.do_8_bit
|
||||
do_8bit_scaled = args.do_8bit_scaled
|
||||
|
||||
# Don't allow both flags to be active simultaneously
|
||||
if do_8_bit and do_8bit_scaled:
|
||||
print("Error: Cannot use both --do_8_bit and --do_8bit_scaled at the same time.")
|
||||
exit()
|
||||
|
||||
if not os.path.exists(flux_path.parent):
|
||||
os.makedirs(flux_path.parent)
|
||||
@@ -373,16 +385,64 @@ def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn):
|
||||
return rounded.to(dtype)
|
||||
|
||||
|
||||
# set all the keys to bf16
|
||||
# List of keys that should not be scaled (usually embedding layers and biases)
|
||||
blacklist = []
|
||||
for key in flux.keys():
|
||||
if do_8_bit:
|
||||
if not key.endswith(".weight") or "embed" in key:
|
||||
blacklist.append(key)
|
||||
|
||||
# Function to scale weights for 8-bit quantization
|
||||
def scale_weights_to_8bit(tensor, max_value=416.0, dtype=torch.float8_e4m3fn):
|
||||
# Get the limits of the dtype
|
||||
min_val = torch.finfo(dtype).min
|
||||
max_val = torch.finfo(dtype).max
|
||||
|
||||
# Only process 2D tensors that are not in the blacklist
|
||||
if tensor.dim() == 2:
|
||||
# Calculate the scaling factor
|
||||
abs_max = torch.max(torch.abs(tensor))
|
||||
scale = abs_max / max_value
|
||||
|
||||
# Scale the tensor and clip to float8 range
|
||||
scaled_tensor = (tensor / scale).clip(min=min_val, max=max_val).to(dtype)
|
||||
|
||||
return scaled_tensor, scale
|
||||
else:
|
||||
# For tensors that shouldn't be scaled, just convert to float8
|
||||
return tensor.clip(min=min_val, max=max_val).to(dtype), None
|
||||
|
||||
|
||||
# set all the keys to appropriate dtype
|
||||
if do_8_bit:
|
||||
print("Converting to 8-bit with stochastic rounding...")
|
||||
for key in flux.keys():
|
||||
flux[key] = stochastic_round_to(
|
||||
flux[key], torch.float8_e4m3fn).to('cpu')
|
||||
else:
|
||||
elif do_8bit_scaled:
|
||||
print("Converting to scaled 8-bit...")
|
||||
scales = {}
|
||||
for key in tqdm.tqdm(flux.keys()):
|
||||
if key.endswith(".weight") and key not in blacklist:
|
||||
flux[key], scale = scale_weights_to_8bit(flux[key])
|
||||
if scale is not None:
|
||||
scale_key = key[:-len(".weight")] + ".scale_weight"
|
||||
scales[scale_key] = scale
|
||||
else:
|
||||
# For non-weight tensors or blacklisted ones, just convert without scaling
|
||||
min_val = torch.finfo(torch.float8_e4m3fn).min
|
||||
max_val = torch.finfo(torch.float8_e4m3fn).max
|
||||
flux[key] = flux[key].clip(min=min_val, max=max_val).to(torch.float8_e4m3fn).to('cpu')
|
||||
|
||||
# Add all the scales to the flux dictionary
|
||||
flux.update(scales)
|
||||
|
||||
# Add a marker tensor to indicate this is a scaled fp8 model
|
||||
flux["scaled_fp8"] = torch.tensor([]).to(torch.float8_e4m3fn)
|
||||
else:
|
||||
print("Converting to bfloat16...")
|
||||
for key in flux.keys():
|
||||
flux[key] = flux[key].clone().to('cpu', torch.bfloat16)
|
||||
|
||||
|
||||
|
||||
meta = OrderedDict()
|
||||
meta['format'] = 'pt'
|
||||
# date format like 2024-08-01 YYYY-MM-DD
|
||||
@@ -394,4 +454,4 @@ print(f"Saving to {flux_path}")
|
||||
|
||||
safetensors.torch.save_file(flux, flux_path, metadata=meta)
|
||||
|
||||
print("Done.")
|
||||
print("Done.")
|
||||
@@ -31,10 +31,6 @@ from toolkit.util.mask import generate_random_mask
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||
from collections import OrderedDict
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
||||
AttnProcessor2_0
|
||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
||||
from ipadapter.ip_adapter.resampler import Resampler
|
||||
from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
import weakref
|
||||
|
||||
@@ -3,7 +3,6 @@ import random
|
||||
import torch
|
||||
import sys
|
||||
|
||||
from PIL import Image
|
||||
from diffusers import Transformer2DModel
|
||||
from torch import nn
|
||||
from torch.nn import Parameter
|
||||
@@ -12,18 +11,14 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
||||
from toolkit.models.zipper_resampler import ZipperResampler
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.saving import load_ip_adapter_model
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from toolkit.util.inverse_cfg import inverse_classifier_guidance
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
|
||||
from collections import OrderedDict
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
||||
AttnProcessor2_0
|
||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
||||
from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler
|
||||
from toolkit.util.ip_adapter_utils import AttnProcessor2_0, IPAttnProcessor2_0, ImageProjModel
|
||||
from toolkit.resampler import Resampler
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
import weakref
|
||||
@@ -35,9 +30,7 @@ if TYPE_CHECKING:
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
CLIPVisionModel,
|
||||
AutoImageProcessor,
|
||||
ConvNextModel,
|
||||
ConvNextV2ForImageClassification,
|
||||
ConvNextForImageClassification,
|
||||
ConvNextImageProcessor
|
||||
@@ -48,9 +41,6 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassificatio
|
||||
|
||||
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
||||
|
||||
# gradient checkpointing
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
|
||||
@@ -7,9 +7,6 @@ from typing import TYPE_CHECKING, List, Dict, Any
|
||||
from toolkit.models.clip_fusion import ZipperBlock
|
||||
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
||||
import sys
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from ipadapter.ip_adapter.resampler import Resampler
|
||||
from collections import OrderedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -4,13 +4,7 @@ import weakref
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, List, Dict, Any
|
||||
from toolkit.models.clip_fusion import ZipperBlock
|
||||
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
||||
import sys
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from ipadapter.ip_adapter.resampler import Resampler
|
||||
from collections import OrderedDict
|
||||
from toolkit.resampler import Resampler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lora_special import LoRAModule
|
||||
|
||||
@@ -5,14 +5,7 @@ from toolkit.config_modules import AdapterConfig
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, List, Dict, Any
|
||||
from toolkit.models.clip_fusion import ZipperBlock
|
||||
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
||||
import sys
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from ipadapter.ip_adapter.resampler import Resampler
|
||||
from collections import OrderedDict
|
||||
from toolkit.resampler import Resampler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lora_special import LoRAModule
|
||||
|
||||
@@ -5,23 +5,15 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import weakref
|
||||
from typing import Union, TYPE_CHECKING
|
||||
|
||||
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
|
||||
from diffusers.models.embeddings import PixArtAlphaTextProjection
|
||||
|
||||
from transformers import T5EncoderModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
|
||||
from toolkit import train_tools
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
from diffusers import Transformer2DModel
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
|
||||
from toolkit.util.ip_adapter_utils import AttnProcessor2_0
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
|
||||
|
||||
|
||||
@@ -10,12 +10,6 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
|
||||
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention
|
||||
|
||||
from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.resampler import Resampler
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
@@ -13,8 +13,6 @@ from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionIma
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
import traceback
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
sys.path.append(REPOS_ROOT)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -15,10 +15,6 @@ from toolkit.train_tools import get_torch_dtype
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||
from collections import OrderedDict
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
||||
AttnProcessor2_0
|
||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
||||
from ipadapter.ip_adapter.resampler import Resampler
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
import weakref
|
||||
|
||||
634
toolkit/util/ip_adapter_utils.py
Normal file
634
toolkit/util/ip_adapter_utils.py
Normal file
@@ -0,0 +1,634 @@
|
||||
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class AttnProcessor(nn.Module):
|
||||
r"""
|
||||
Default processor for performing attention-related computations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(
|
||||
hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
||||
encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(
|
||||
-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAttnProcessor(nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapater.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
||||
The context length of the image features.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.scale = scale
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.to_k_ip = nn.Linear(
|
||||
cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(
|
||||
cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(
|
||||
hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
# get encoder_hidden_states, ip_hidden_states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
encoder_hidden_states[:, end_pos:, :],
|
||||
)
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
||||
encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
self.attn_map = ip_attention_probs
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(
|
||||
-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(
|
||||
batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(
|
||||
hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
||||
encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads,
|
||||
head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads,
|
||||
head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(
|
||||
-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapater for PyTorch 2.0.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
||||
The context length of the image features.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.scale = scale
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.to_k_ip = nn.Linear(
|
||||
cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_ip = nn.Linear(
|
||||
cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(
|
||||
batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(
|
||||
hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
# get encoder_hidden_states, ip_hidden_states
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
encoder_hidden_states[:, end_pos:, :],
|
||||
)
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
||||
encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads,
|
||||
head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads,
|
||||
head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads,
|
||||
head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(
|
||||
batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
with torch.no_grad():
|
||||
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
|
||||
# print(self.attn_map.shape)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(
|
||||
1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(
|
||||
-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# for controlnet
|
||||
class CNAttnProcessor:
|
||||
r"""
|
||||
Default processor for performing attention-related computations.
|
||||
"""
|
||||
|
||||
def __init__(self, num_tokens=4):
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(
|
||||
hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
# only use text
|
||||
encoder_hidden_states = encoder_hidden_states[:, :end_pos]
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
||||
encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(
|
||||
-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CNAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self, num_tokens=4):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(
|
||||
batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(
|
||||
hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||
# only use text
|
||||
encoder_hidden_states = encoder_hidden_states[:, :end_pos]
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
||||
encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads,
|
||||
head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads,
|
||||
head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(
|
||||
-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ImageProjModel(torch.nn.Module):
|
||||
"""Projection Model"""
|
||||
|
||||
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
self.generator = None
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.clip_extra_context_tokens = clip_extra_context_tokens
|
||||
self.proj = torch.nn.Linear(
|
||||
clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, image_embeds):
|
||||
embeds = image_embeds
|
||||
clip_extra_context_tokens = self.proj(embeds).reshape(
|
||||
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
||||
)
|
||||
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
||||
return clip_extra_context_tokens
|
||||
Reference in New Issue
Block a user