mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Remove ip adapter submodule
This commit is contained in:
4
.gitmodules
vendored
4
.gitmodules
vendored
@@ -10,7 +10,3 @@
|
|||||||
path = repositories/batch_annotator
|
path = repositories/batch_annotator
|
||||||
url = https://github.com/ostris/batch-annotator
|
url = https://github.com/ostris/batch-annotator
|
||||||
commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
|
commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf
|
||||||
[submodule "repositories/ipadapter"]
|
|
||||||
path = repositories/ipadapter
|
|
||||||
url = https://github.com/tencent-ailab/IP-Adapter.git
|
|
||||||
commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14
|
|
||||||
|
|||||||
Submodule repositories/ipadapter deleted from 5a18b1f366
@@ -2,10 +2,14 @@
|
|||||||
# Convert Diffusers Flux/Flex to diffusion model ComfyUI safetensors file
|
# Convert Diffusers Flux/Flex to diffusion model ComfyUI safetensors file
|
||||||
# This will only have the transformer weights, not the TEs and VAE
|
# This will only have the transformer weights, not the TEs and VAE
|
||||||
# You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag
|
# You can save the transformer weights as bf16 or 8-bit with the --do_8_bit flag
|
||||||
|
# You can also save with scaled 8-bit using the --do_8bit_scaled flag
|
||||||
#
|
#
|
||||||
# Call like this for 8-bit transformer weights:
|
# Call like this for 8-bit transformer weights with stochastic rounding:
|
||||||
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8_bit
|
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8_bit
|
||||||
#
|
#
|
||||||
|
# Call like this for 8-bit transformer weights with scaling:
|
||||||
|
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors --do_8bit_scaled
|
||||||
|
#
|
||||||
# Call like this for bf16 transformer weights:
|
# Call like this for bf16 transformer weights:
|
||||||
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors
|
# python convert_diffusers_to_comfy_transformer_only.py /path/to/diffusers/checkpoint /output/path/my_finetune.safetensors
|
||||||
#
|
#
|
||||||
@@ -33,7 +37,9 @@ parser.add_argument("diffusers_path", type=str,
|
|||||||
parser.add_argument("flux_path", type=str,
|
parser.add_argument("flux_path", type=str,
|
||||||
help="Output path for the Flux safetensors file.")
|
help="Output path for the Flux safetensors file.")
|
||||||
parser.add_argument("--do_8_bit", action="store_true",
|
parser.add_argument("--do_8_bit", action="store_true",
|
||||||
help="Use 8-bit weights instead of bf16.")
|
help="Use 8-bit weights with stochastic rounding instead of bf16.")
|
||||||
|
parser.add_argument("--do_8bit_scaled", action="store_true",
|
||||||
|
help="Use scaled 8-bit weights instead of bf16.")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
flux_path = Path(args.flux_path)
|
flux_path = Path(args.flux_path)
|
||||||
@@ -43,6 +49,12 @@ if os.path.exists(os.path.join(diffusers_path, "transformer")):
|
|||||||
diffusers_path = Path(os.path.join(diffusers_path, "transformer"))
|
diffusers_path = Path(os.path.join(diffusers_path, "transformer"))
|
||||||
|
|
||||||
do_8_bit = args.do_8_bit
|
do_8_bit = args.do_8_bit
|
||||||
|
do_8bit_scaled = args.do_8bit_scaled
|
||||||
|
|
||||||
|
# Don't allow both flags to be active simultaneously
|
||||||
|
if do_8_bit and do_8bit_scaled:
|
||||||
|
print("Error: Cannot use both --do_8_bit and --do_8bit_scaled at the same time.")
|
||||||
|
exit()
|
||||||
|
|
||||||
if not os.path.exists(flux_path.parent):
|
if not os.path.exists(flux_path.parent):
|
||||||
os.makedirs(flux_path.parent)
|
os.makedirs(flux_path.parent)
|
||||||
@@ -373,16 +385,64 @@ def stochastic_round_to(tensor, dtype=torch.float8_e4m3fn):
|
|||||||
return rounded.to(dtype)
|
return rounded.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
# set all the keys to bf16
|
# List of keys that should not be scaled (usually embedding layers and biases)
|
||||||
|
blacklist = []
|
||||||
for key in flux.keys():
|
for key in flux.keys():
|
||||||
if do_8_bit:
|
if not key.endswith(".weight") or "embed" in key:
|
||||||
|
blacklist.append(key)
|
||||||
|
|
||||||
|
# Function to scale weights for 8-bit quantization
|
||||||
|
def scale_weights_to_8bit(tensor, max_value=416.0, dtype=torch.float8_e4m3fn):
|
||||||
|
# Get the limits of the dtype
|
||||||
|
min_val = torch.finfo(dtype).min
|
||||||
|
max_val = torch.finfo(dtype).max
|
||||||
|
|
||||||
|
# Only process 2D tensors that are not in the blacklist
|
||||||
|
if tensor.dim() == 2:
|
||||||
|
# Calculate the scaling factor
|
||||||
|
abs_max = torch.max(torch.abs(tensor))
|
||||||
|
scale = abs_max / max_value
|
||||||
|
|
||||||
|
# Scale the tensor and clip to float8 range
|
||||||
|
scaled_tensor = (tensor / scale).clip(min=min_val, max=max_val).to(dtype)
|
||||||
|
|
||||||
|
return scaled_tensor, scale
|
||||||
|
else:
|
||||||
|
# For tensors that shouldn't be scaled, just convert to float8
|
||||||
|
return tensor.clip(min=min_val, max=max_val).to(dtype), None
|
||||||
|
|
||||||
|
|
||||||
|
# set all the keys to appropriate dtype
|
||||||
|
if do_8_bit:
|
||||||
|
print("Converting to 8-bit with stochastic rounding...")
|
||||||
|
for key in flux.keys():
|
||||||
flux[key] = stochastic_round_to(
|
flux[key] = stochastic_round_to(
|
||||||
flux[key], torch.float8_e4m3fn).to('cpu')
|
flux[key], torch.float8_e4m3fn).to('cpu')
|
||||||
else:
|
elif do_8bit_scaled:
|
||||||
|
print("Converting to scaled 8-bit...")
|
||||||
|
scales = {}
|
||||||
|
for key in tqdm.tqdm(flux.keys()):
|
||||||
|
if key.endswith(".weight") and key not in blacklist:
|
||||||
|
flux[key], scale = scale_weights_to_8bit(flux[key])
|
||||||
|
if scale is not None:
|
||||||
|
scale_key = key[:-len(".weight")] + ".scale_weight"
|
||||||
|
scales[scale_key] = scale
|
||||||
|
else:
|
||||||
|
# For non-weight tensors or blacklisted ones, just convert without scaling
|
||||||
|
min_val = torch.finfo(torch.float8_e4m3fn).min
|
||||||
|
max_val = torch.finfo(torch.float8_e4m3fn).max
|
||||||
|
flux[key] = flux[key].clip(min=min_val, max=max_val).to(torch.float8_e4m3fn).to('cpu')
|
||||||
|
|
||||||
|
# Add all the scales to the flux dictionary
|
||||||
|
flux.update(scales)
|
||||||
|
|
||||||
|
# Add a marker tensor to indicate this is a scaled fp8 model
|
||||||
|
flux["scaled_fp8"] = torch.tensor([]).to(torch.float8_e4m3fn)
|
||||||
|
else:
|
||||||
|
print("Converting to bfloat16...")
|
||||||
|
for key in flux.keys():
|
||||||
flux[key] = flux[key].clone().to('cpu', torch.bfloat16)
|
flux[key] = flux[key].clone().to('cpu', torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
meta = OrderedDict()
|
meta = OrderedDict()
|
||||||
meta['format'] = 'pt'
|
meta['format'] = 'pt'
|
||||||
# date format like 2024-08-01 YYYY-MM-DD
|
# date format like 2024-08-01 YYYY-MM-DD
|
||||||
@@ -394,4 +454,4 @@ print(f"Saving to {flux_path}")
|
|||||||
|
|
||||||
safetensors.torch.save_file(flux, flux_path, metadata=meta)
|
safetensors.torch.save_file(flux, flux_path, metadata=meta)
|
||||||
|
|
||||||
print("Done.")
|
print("Done.")
|
||||||
@@ -31,10 +31,6 @@ from toolkit.util.mask import generate_random_mask
|
|||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
|
||||||
AttnProcessor2_0
|
|
||||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
|
||||||
from ipadapter.ip_adapter.resampler import Resampler
|
|
||||||
from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
|
from toolkit.config_modules import AdapterConfig, AdapterTypes, TrainConfig
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
import weakref
|
import weakref
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import random
|
|||||||
import torch
|
import torch
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from diffusers import Transformer2DModel
|
from diffusers import Transformer2DModel
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
@@ -12,18 +11,14 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
|||||||
|
|
||||||
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
||||||
from toolkit.models.zipper_resampler import ZipperResampler
|
from toolkit.models.zipper_resampler import ZipperResampler
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
from toolkit.saving import load_ip_adapter_model
|
from toolkit.saving import load_ip_adapter_model
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
from toolkit.util.inverse_cfg import inverse_classifier_guidance
|
from toolkit.util.inverse_cfg import inverse_classifier_guidance
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
|
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
from toolkit.util.ip_adapter_utils import AttnProcessor2_0, IPAttnProcessor2_0, ImageProjModel
|
||||||
AttnProcessor2_0
|
from toolkit.resampler import Resampler
|
||||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
|
||||||
from ipadapter.ip_adapter.resampler import PerceiverAttention, FeedForward, Resampler
|
|
||||||
from toolkit.config_modules import AdapterConfig
|
from toolkit.config_modules import AdapterConfig
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
import weakref
|
import weakref
|
||||||
@@ -35,9 +30,7 @@ if TYPE_CHECKING:
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
CLIPImageProcessor,
|
CLIPImageProcessor,
|
||||||
CLIPVisionModelWithProjection,
|
CLIPVisionModelWithProjection,
|
||||||
CLIPVisionModel,
|
|
||||||
AutoImageProcessor,
|
AutoImageProcessor,
|
||||||
ConvNextModel,
|
|
||||||
ConvNextV2ForImageClassification,
|
ConvNextV2ForImageClassification,
|
||||||
ConvNextForImageClassification,
|
ConvNextForImageClassification,
|
||||||
ConvNextImageProcessor
|
ConvNextImageProcessor
|
||||||
@@ -48,9 +41,6 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassificatio
|
|||||||
|
|
||||||
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
from transformers import ViTFeatureExtractor, ViTForImageClassification
|
||||||
|
|
||||||
# gradient checkpointing
|
|
||||||
from torch.utils.checkpoint import checkpoint
|
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,6 @@ from typing import TYPE_CHECKING, List, Dict, Any
|
|||||||
from toolkit.models.clip_fusion import ZipperBlock
|
from toolkit.models.clip_fusion import ZipperBlock
|
||||||
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
||||||
import sys
|
import sys
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
from ipadapter.ip_adapter.resampler import Resampler
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -4,13 +4,7 @@ import weakref
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import TYPE_CHECKING, List, Dict, Any
|
from typing import TYPE_CHECKING, List, Dict, Any
|
||||||
from toolkit.models.clip_fusion import ZipperBlock
|
from toolkit.resampler import Resampler
|
||||||
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
|
||||||
import sys
|
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
from ipadapter.ip_adapter.resampler import Resampler
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.lora_special import LoRAModule
|
from toolkit.lora_special import LoRAModule
|
||||||
|
|||||||
@@ -5,14 +5,7 @@ from toolkit.config_modules import AdapterConfig
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing import TYPE_CHECKING, List, Dict, Any
|
from typing import TYPE_CHECKING, List, Dict, Any
|
||||||
from toolkit.models.clip_fusion import ZipperBlock
|
from toolkit.resampler import Resampler
|
||||||
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
|
||||||
import sys
|
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
from ipadapter.ip_adapter.resampler import Resampler
|
|
||||||
from collections import OrderedDict
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.lora_special import LoRAModule
|
from toolkit.lora_special import LoRAModule
|
||||||
|
|||||||
@@ -5,23 +5,15 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Union, TYPE_CHECKING
|
from typing import Union, TYPE_CHECKING
|
||||||
|
from transformers import T5EncoderModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
|
||||||
|
|
||||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
|
|
||||||
from diffusers.models.embeddings import PixArtAlphaTextProjection
|
|
||||||
|
|
||||||
from toolkit import train_tools
|
from toolkit import train_tools
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
from diffusers import Transformer2DModel
|
from diffusers import Transformer2DModel
|
||||||
|
from toolkit.util.ip_adapter_utils import AttnProcessor2_0
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
|
|
||||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion, PixArtSigmaPipeline
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
from toolkit.custom_adapter import CustomAdapter
|
from toolkit.custom_adapter import CustomAdapter
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,12 +10,6 @@ from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokeniz
|
|||||||
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention
|
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention
|
||||||
|
|
||||||
from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
|
from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
from toolkit.resampler import Resampler
|
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
|
|
||||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
|
|||||||
@@ -13,8 +13,6 @@ from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionIma
|
|||||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||||
import traceback
|
import traceback
|
||||||
from toolkit.config_modules import AdapterConfig
|
from toolkit.config_modules import AdapterConfig
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -15,10 +15,6 @@ from toolkit.train_tools import get_torch_dtype
|
|||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
|
||||||
AttnProcessor2_0
|
|
||||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
|
||||||
from ipadapter.ip_adapter.resampler import Resampler
|
|
||||||
from toolkit.config_modules import AdapterConfig
|
from toolkit.config_modules import AdapterConfig
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
import weakref
|
import weakref
|
||||||
|
|||||||
634
toolkit/util/ip_adapter_utils.py
Normal file
634
toolkit/util/ip_adapter_utils.py
Normal file
@@ -0,0 +1,634 @@
|
|||||||
|
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.

    The processor holds no weights of its own: every projection, norm and
    helper it uses is read from the ``attn`` module passed to ``__call__``.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        # Both arguments are accepted but unused.
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run attention using the layers owned by ``attn``.

        Args:
            attn: Attention module providing ``to_q``/``to_k``/``to_v``/
                ``to_out``, the optional norms, and the head reshaping and
                score helpers used below.
            hidden_states: Query-side input, either 3D or 4D. A 4D input is
                flattened over its last two dims and restored on return.
            encoder_hidden_states: Optional key/value-side input. When None
                the call is self-attention over ``hidden_states``.
            attention_mask: Optional mask, handed to
                ``attn.prepare_attention_mask``.
            temb: Passed to ``attn.spatial_norm`` when that norm is set.

        Returns:
            The attended hidden states, with the same layout as the input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        # The mask is sized by the key/value sequence.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # (B, H*W, C) -> (B, C, H, W)
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        # Extra key/value projections applied only to the image-prompt tokens.
        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run attention plus a second attention pass over the image tokens.

        The last ``self.num_tokens`` entries along dim 1 of
        ``encoder_hidden_states`` are treated as image-prompt tokens: they are
        split off, attended to through ``to_k_ip``/``to_v_ip`` and added to
        the regular attention output weighted by ``self.scale``.

        When ``encoder_hidden_states`` is None there are no image tokens, so
        only the regular attention is computed.

        Args:
            attn: Attention module providing the projections, norms and
                head/score helpers used below.
            hidden_states: Query-side input, either 3D or 4D.
            encoder_hidden_states: Optional context whose trailing
                ``num_tokens`` entries are the image-prompt tokens.
            attention_mask: Optional mask for the regular attention pass; the
                image-token pass is unmasked.
            temb: Passed to ``attn.spatial_norm`` when that norm is set.

        Returns:
            The attended hidden states, with the same layout as the input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Stays None for self-attention; previously the name was left unbound
        # in that case and the image-token pass below raised UnboundLocalError.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        if ip_hidden_states is not None:
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)

            ip_key = attn.head_to_batch_dim(ip_key)
            ip_value = attn.head_to_batch_dim(ip_value)

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            # Kept on the instance so the most recent image attention map can
            # be read back after the call.
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    The processor holds no weights of its own: projections and norms are read
    from the ``attn`` module passed to ``__call__``.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        # Both arguments are accepted but unused.
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """Run attention through ``F.scaled_dot_product_attention``.

        Args:
            attn: Attention module providing ``to_q``/``to_k``/``to_v``/
                ``to_out``, ``heads``, and the optional norms.
            hidden_states: Query-side input, either 3D or 4D. A 4D input is
                flattened over its last two dims and restored on return.
            encoder_hidden_states: Optional key/value-side input. When None
                the call is self-attention over ``hidden_states``.
            attention_mask: Optional mask; prepared via
                ``attn.prepare_attention_mask`` and reshaped per head.
            temb: Passed to ``attn.spatial_norm`` when that norm is set.

        Returns:
            The attended hidden states, with the same layout as the input.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads,
                           head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class IPAttnProcessor2_0(torch.nn.Module):
|
||||||
|
r"""
|
||||||
|
Attention processor for IP-Adapter for PyTorch 2.0.
|
||||||
|
Args:
|
||||||
|
hidden_size (`int`):
|
||||||
|
The hidden size of the attention layer.
|
||||||
|
cross_attention_dim (`int`):
|
||||||
|
The number of channels in the `encoder_hidden_states`.
|
||||||
|
scale (`float`, defaults to 1.0):
|
||||||
|
the weight scale of image prompt.
|
||||||
|
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
||||||
|
The context length of the image features.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError(
|
||||||
|
"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
self.scale = scale
|
||||||
|
self.num_tokens = num_tokens
|
||||||
|
|
||||||
|
self.to_k_ip = nn.Linear(
|
||||||
|
cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||||
|
self.to_v_ip = nn.Linear(
|
||||||
|
cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn,
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=None,
|
||||||
|
attention_mask=None,
|
||||||
|
temb=None,
|
||||||
|
):
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
if attn.spatial_norm is not None:
|
||||||
|
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||||
|
|
||||||
|
input_ndim = hidden_states.ndim
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
batch_size, channel, height, width = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(
|
||||||
|
batch_size, channel, height * width).transpose(1, 2)
|
||||||
|
|
||||||
|
batch_size, sequence_length, _ = (
|
||||||
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attn.prepare_attention_mask(
|
||||||
|
attention_mask, sequence_length, batch_size)
|
||||||
|
# scaled_dot_product_attention expects attention_mask shape to be
|
||||||
|
# (batch, heads, source_length, target_length)
|
||||||
|
attention_mask = attention_mask.view(
|
||||||
|
batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||||
|
|
||||||
|
if attn.group_norm is not None:
|
||||||
|
hidden_states = attn.group_norm(
|
||||||
|
hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||||
|
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
else:
|
||||||
|
# get encoder_hidden_states, ip_hidden_states
|
||||||
|
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
||||||
|
encoder_hidden_states, ip_hidden_states = (
|
||||||
|
encoder_hidden_states[:, :end_pos, :],
|
||||||
|
encoder_hidden_states[:, end_pos:, :],
|
||||||
|
)
|
||||||
|
if attn.norm_cross:
|
||||||
|
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
||||||
|
encoder_hidden_states)
|
||||||
|
|
||||||
|
key = attn.to_k(encoder_hidden_states)
|
||||||
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
|
||||||
|
query = query.view(batch_size, -1, attn.heads,
|
||||||
|
head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads,
|
||||||
|
head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||||
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
|
hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(
|
||||||
|
batch_size, -1, attn.heads * head_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# for ip-adapter
|
||||||
|
ip_key = self.to_k_ip(ip_hidden_states)
|
||||||
|
ip_value = self.to_v_ip(ip_hidden_states)
|
||||||
|
|
||||||
|
ip_key = ip_key.view(batch_size, -1, attn.heads,
|
||||||
|
head_dim).transpose(1, 2)
|
||||||
|
ip_value = ip_value.view(
|
||||||
|
batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||||
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
|
ip_hidden_states = F.scaled_dot_product_attention(
|
||||||
|
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
|
)
|
||||||
|
with torch.no_grad():
|
||||||
|
self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
|
||||||
|
# print(self.attn_map.shape)
|
||||||
|
|
||||||
|
ip_hidden_states = ip_hidden_states.transpose(
|
||||||
|
1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
|
||||||
|
if input_ndim == 4:
|
||||||
|
hidden_states = hidden_states.transpose(
|
||||||
|
-1, -2).reshape(batch_size, channel, height, width)
|
||||||
|
|
||||||
|
if attn.residual_connection:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
|
||||||
|
hidden_states = hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# for controlnet
|
||||||
|
class CNAttnProcessor:
    r"""
    Attention processor that drops the trailing context tokens.

    The attention itself is computed with `attn.get_attention_scores` and
    `torch.bmm`. Before keys and values are projected, the last `num_tokens`
    entries of `encoder_hidden_states` are sliced off, so only the leading
    (text) tokens take part in cross-attention.
    """

    def __init__(self, num_tokens=4):
        # How many entries at the end of the context sequence to discard.
        self.num_tokens = num_tokens

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        # Kept untouched for the optional residual connection at the end.
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the last two dims into one sequence axis.
            batch_size, channel, height, width = hidden_states.shape
            flattened = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flattened.transpose(1, 2)

        # Batch size and length come from the context when one is supplied
        # (measured before the trailing tokens are removed).
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            # only use text
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        probs = attn.get_attention_scores(query, key, attention_mask)
        attended = torch.bmm(probs, value)
        hidden_states = attn.batch_to_head_dim(attended)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        for layer_index in (0, 1):
            hidden_states = attn.to_out[layer_index](hidden_states)

        if input_ndim == 4:
            restored = hidden_states.transpose(-1, -2)
            hidden_states = restored.reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
|
||||||
|
class CNAttnProcessor2_0:
    r"""
    Attention processor built on `F.scaled_dot_product_attention` that drops
    the trailing context tokens.

    Before keys and values are projected, the last `num_tokens` entries of
    `encoder_hidden_states` are sliced off, so only the leading (text) tokens
    take part in cross-attention.
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        # How many entries at the end of the context sequence to discard.
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        # Kept untouched for the optional residual connection at the end.
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # Flatten the last two dims into one sequence axis.
            batch_size, channel, height, width = hidden_states.shape
            flattened = hidden_states.view(batch_size, channel, height * width)
            hidden_states = flattened.transpose(1, 2)

        # Batch size and length come from the context when one is supplied
        # (measured before the trailing tokens are removed).
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape

        if attention_mask is not None:
            prepared = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = prepared.view(
                batch_size, attn.heads, -1, prepared.shape[-1])

        if attn.group_norm is not None:
            normed = attn.group_norm(hidden_states.transpose(1, 2))
            hidden_states = normed.transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        else:
            # only use text
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads
        per_head_shape = (batch_size, -1, attn.heads, head_dim)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(*per_head_shape).transpose(1, 2)
        key = key.view(*per_head_shape).transpose(1, 2)
        value = value.view(*per_head_shape).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        merged = attended.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim)
        hidden_states = merged.to(query.dtype)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        for layer_index in (0, 1):
            hidden_states = attn.to_out[layer_index](hidden_states)

        if input_ndim == 4:
            restored = hidden_states.transpose(-1, -2)
            hidden_states = restored.reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + shortcut

        return hidden_states / attn.rescale_output_factor
|
||||||
|
|
||||||
|
|
||||||
|
class ImageProjModel(torch.nn.Module):
    """Project image embeddings into a fixed number of context tokens.

    A single linear layer maps each embedding to
    `clip_extra_context_tokens * cross_attention_dim` features, which are
    reshaped into `clip_extra_context_tokens` tokens of width
    `cross_attention_dim` and passed through a LayerNorm.
    """

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens

        # One linear layer emits all tokens at once; forward() splits them.
        projected_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = torch.nn.Linear(clip_embeddings_dim, projected_width)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalized tokens of shape (-1, clip_extra_context_tokens, cross_attention_dim)."""
        token_shape = (-1, self.clip_extra_context_tokens, self.cross_attention_dim)
        tokens = self.proj(image_embeds).reshape(*token_shape)
        return self.norm(tokens)
|
||||||
Reference in New Issue
Block a user