WIP: add the caption_proj weight to the PixArt Sigma TE adapter

Jaret Burkett
2024-07-06 13:00:21 -06:00
parent acb06d6ff3
commit cab8a1c7b8
8 changed files with 500 additions and 23 deletions


@@ -0,0 +1,267 @@
import math
import weakref
import torch
import torch.nn as nn
from typing import TYPE_CHECKING, List, Dict, Any
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
import sys
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.resampler import Resampler
from collections import OrderedDict
if TYPE_CHECKING:
    from toolkit.lora_special import LoRAModule
    from toolkit.stable_diffusion_model import StableDiffusion


class TransformerBlock(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, cross_attn_input):
        # Self-attention
        attn_output, _ = self.self_attn(x, x, x)
        x = self.norm1(x + attn_output)

        # Cross-attention
        cross_attn_output, _ = self.cross_attn(x, cross_attn_input, cross_attn_input)
        x = self.norm2(x + cross_attn_output)

        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm3(x + ff_output)
        return x
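
# Minimal (hypothetical) usage sketch for TransformerBlock, assuming a batch of latent
# tokens `x` of shape (batch, tokens, d_model) and image-embedding tokens `img_tokens`
# with the same last dimension:
#   block = TransformerBlock(d_model=1024, nhead=16, dim_feedforward=4096)
#   out = block(x, img_tokens)  # same shape as x; x self-attends, then cross-attends to img_tokens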
class InstantLoRAMidModule(torch.nn.Module):
    def __init__(
            self,
            index: int,
            lora_module: 'LoRAModule',
            instant_lora_module: 'InstantLoRAModule',
            up_shape: list = None,
            down_shape: list = None,
    ):
        super(InstantLoRAMidModule, self).__init__()
        self.up_shape = up_shape
        self.down_shape = down_shape
        self.index = index
        self.lora_module_ref = weakref.ref(lora_module)
        self.instant_lora_module_ref = weakref.ref(instant_lora_module)
        self.embed = None

    def down_forward(self, x, *args, **kwargs):
        # get the embed
        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
        down_size = math.prod(self.down_shape)
        down_weight = self.embed[:, :down_size]
        batch_size = x.shape[0]
        # unconditional
        if down_weight.shape[0] * 2 == batch_size:
            down_weight = torch.cat([down_weight] * 2, dim=0)
        weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
        x_chunks = torch.chunk(x, batch_size, dim=0)
        x_out = []
        for i in range(batch_size):
            weight_chunk = weight_chunks[i]
            x_chunk = x_chunks[i]
            # reshape
            weight_chunk = weight_chunk.view(self.down_shape)
            # check if is conv or linear
            if len(weight_chunk.shape) == 4:
                padding = 0
                if weight_chunk.shape[-1] == 3:
                    padding = 1
                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
            else:
                # run a simple linear layer with the down weight
                x_chunk = x_chunk @ weight_chunk.T
            x_out.append(x_chunk)
        x = torch.cat(x_out, dim=0)
        return x

    def up_forward(self, x, *args, **kwargs):
        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
        up_size = math.prod(self.up_shape)
        up_weight = self.embed[:, -up_size:]
        batch_size = x.shape[0]
        # unconditional
        if up_weight.shape[0] * 2 == batch_size:
            up_weight = torch.cat([up_weight] * 2, dim=0)
        weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
        x_chunks = torch.chunk(x, batch_size, dim=0)
        x_out = []
        for i in range(batch_size):
            weight_chunk = weight_chunks[i]
            x_chunk = x_chunks[i]
            # reshape
            weight_chunk = weight_chunk.view(self.up_shape)
            # check if is conv or linear
            if len(weight_chunk.shape) == 4:
                padding = 0
                if weight_chunk.shape[-1] == 3:
                    padding = 1
                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
            else:
                # run a simple linear layer with the up weight
                x_chunk = x_chunk @ weight_chunk.T
            x_out.append(x_chunk)
        x = torch.cat(x_out, dim=0)
        return x
# Initialize the network
# num_blocks = 8
# d_model = 1024 # Adjust as needed
# nhead = 16 # Adjust as needed
# dim_feedforward = 4096 # Adjust as needed
# latent_dim = 1695744
class LoRAFormer(torch.nn.Module):
    def __init__(
            self,
            num_blocks,
            d_model=1024,
            nhead=16,
            dim_feedforward=4096,
            sd: 'StableDiffusion' = None,
    ):
        super(LoRAFormer, self).__init__()
        # self.linear = torch.nn.Linear(2, 1)
        self.sd_ref = weakref.ref(sd)
        self.dim = sd.network.lora_dim

        # stores the projection vector. Grabbed by modules
        self.img_embeds: List[torch.Tensor] = None

        # disable merging in. It is slower on inference
        self.sd_ref().network.can_merge_in = False

        self.ilora_modules = torch.nn.ModuleList()

        lora_modules = self.sd_ref().network.get_all_modules()

        output_size = 0
        self.embed_lengths = []
        self.weight_mapping = []

        for idx, lora_module in enumerate(lora_modules):
            module_dict = lora_module.state_dict()
            down_shape = list(module_dict['lora_down.weight'].shape)
            up_shape = list(module_dict['lora_up.weight'].shape)

            self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])

            module_size = math.prod(down_shape) + math.prod(up_shape)
            output_size += module_size
            self.embed_lengths.append(module_size)

            # add a new mid module that will take the original forward and add a vector to it
            # this will be used to add the vector to the original forward
            instant_module = InstantLoRAMidModule(
                idx,
                lora_module,
                self,
                up_shape=up_shape,
                down_shape=down_shape
            )
            self.ilora_modules.append(instant_module)

            # replace the LoRA forwards
            lora_module.lora_down.forward = instant_module.down_forward
            lora_module.lora_up.forward = instant_module.up_forward

        self.output_size = output_size

        self.latent = nn.Parameter(torch.randn(1, output_size))
        self.latent_proj = nn.Linear(output_size, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, nhead, dim_feedforward)
            for _ in range(num_blocks)
        ])
        self.final_proj = nn.Linear(d_model, output_size)

        self.migrate_weight_mapping()

    def migrate_weight_mapping(self):
        return
        # # changes the names of the modules to common ones
        # keymap = self.sd_ref().network.get_keymap()
        # save_keymap = {}
        # if keymap is not None:
        #     for ldm_key, diffusers_key in keymap.items():
        #         # invert them
        #         save_keymap[diffusers_key] = ldm_key
        #
        #     new_keymap = {}
        #     for key, value in self.weight_mapping:
        #         if key in save_keymap:
        #             new_keymap[save_keymap[key]] = value
        #         else:
        #             print(f"Key {key} not found in keymap")
        #             new_keymap[key] = value
        #     self.weight_mapping = new_keymap
        # else:
        #     print("No keymap found. Using default names")
        #     return
    def forward(self, img_embeds):
        # expand token rank if only rank 2
        if len(img_embeds.shape) == 2:
            img_embeds = img_embeds.unsqueeze(1)
        batch_size = img_embeds.shape[0]

        # WIP sketch: run the learned latent through the transformer blocks, cross-attending
        # to the image embeddings (assumes their last dim matches d_model)
        x = self.latent_proj(self.latent)  # (1, d_model)
        x = x.unsqueeze(1).repeat(batch_size, 1, 1)  # (batch, 1, d_model)
        for block in self.blocks:
            x = block(x, img_embeds)
        img_embeds = self.final_proj(x)  # (batch, 1, output_size)

        if len(img_embeds.shape) == 3:
            # merge the token dim
            img_embeds = img_embeds.mean(dim=1)

        self.img_embeds = []
        # get all the slices
        start = 0
        for length in self.embed_lengths:
            self.img_embeds.append(img_embeds[:, start:start + length])
            start += length
    def get_additional_save_metadata(self) -> Dict[str, Any]:
        # save the weight mapping (only attributes that exist on this module;
        # num_heads / vision_hidden_size / head_dim / vision_tokens are not set here)
        return {
            "weight_mapping": self.weight_mapping,
            "output_size": self.output_size,
        }
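
# Rough usage sketch (hypothetical names; `sd` is a StableDiffusion wrapper that already
# has a LoRA network attached, `clip_embeds` is a (batch, d_model) image-embedding tensor):
#   lora_former = LoRAFormer(num_blocks=8, d_model=1024, nhead=16, dim_feedforward=4096, sd=sd)
#   lora_former(clip_embeds)  # fills lora_former.img_embeds, which the mid modules read
#   # subsequent UNet/transformer forwards then use the per-sample LoRA weights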


@@ -156,10 +156,10 @@ class InstantLoRAMidModule(torch.nn.Module):
     weight_chunk = weight_chunk.view(self.down_shape)
     # check if is conv or linear
     if len(weight_chunk.shape) == 4:
-        padding = 0
-        if weight_chunk.shape[-1] == 3:
-            padding = 1
-        x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
+        org_module = self.lora_module_ref().orig_module_ref()
+        stride = org_module.stride
+        padding = org_module.padding
+        x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding, stride=stride)
     else:
         # run a simple linear layer with the down weight
         x_chunk = x_chunk @ weight_chunk.T


@@ -6,7 +6,9 @@ import torch.nn.functional as F
import weakref
from typing import Union, TYPE_CHECKING
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
from diffusers.models.embeddings import PixArtAlphaTextProjection
from toolkit import train_tools
from toolkit.paths import REPOS_ROOT
@@ -17,11 +19,71 @@ sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.custom_adapter import CustomAdapter


class TEAdapterCaptionProjection(nn.Module):
    def __init__(self, caption_channels, adapter: 'TEAdapter'):
        super().__init__()
        in_features = caption_channels
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        sd = adapter.sd_ref()
        self.parent_module_ref = weakref.ref(sd.transformer.caption_projection)
        parent_module = self.parent_module_ref()
        self.linear_1 = nn.Linear(
            in_features=in_features,
            out_features=parent_module.linear_1.out_features,
            bias=True
        )
        self.linear_2 = nn.Linear(
            in_features=parent_module.linear_2.in_features,
            out_features=parent_module.linear_2.out_features,
            bias=True
        )

        # save the orig forward
        parent_module.linear_1.orig_forward = parent_module.linear_1.forward
        parent_module.linear_2.orig_forward = parent_module.linear_2.forward

        # replace original forward
        parent_module.orig_forward = parent_module.forward
        parent_module.forward = self.forward

    @property
    def is_active(self):
        return self.adapter_ref().is_active

    @property
    def unconditional_embeds(self):
        return self.adapter_ref().adapter_ref().unconditional_embeds

    @property
    def conditional_embeds(self):
        return self.adapter_ref().adapter_ref().conditional_embeds

    def forward(self, caption):
        if self.is_active and self.conditional_embeds is not None:
            adapter_hidden_states = self.conditional_embeds.text_embeds
            # check if we are doing unconditional
            if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != caption.shape[0]:
                # concat unconditional to match the hidden state batch size
                if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
                    unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
                else:
                    unconditional = self.unconditional_embeds.text_embeds
                adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
            hidden_states = self.linear_1(adapter_hidden_states)
            hidden_states = self.parent_module_ref().act_1(hidden_states)
            hidden_states = self.linear_2(hidden_states)
            return hidden_states
        else:
            return self.parent_module_ref().orig_forward(caption)
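
# How the hook above is wired (hedged summary, assuming a PixArt-style transformer):
#   - TEAdapterCaptionProjection swaps itself in for sd.transformer.caption_projection.forward,
#     so it sees the `caption` tensor the transformer would normally project.
#   - When the adapter is active it instead projects the adapter's conditional/unconditional
#     text embeds through its own linear_1 / linear_2 (reusing the parent's act_1 activation);
#     otherwise it falls back to the parent's original forward.
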
class TEAdapterAttnProcessor(nn.Module):
r"""
Attention processor for Custom TE for PyTorch 2.0.
@@ -177,6 +239,8 @@ class TEAdapter(torch.nn.Module):
        self.te_ref: weakref.ref = weakref.ref(te)
        self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
        self.adapter_modules = []
        self.caption_projection = None
        self.embeds_store = []
        is_pixart = sd.is_pixart
        if self.adapter_ref().config.text_encoder_arch == "t5":
@@ -297,6 +361,11 @@ class TEAdapter(torch.nn.Module):
                transformer.transformer_blocks[i].attn2.processor for i in
                range(len(transformer.transformer_blocks))
            ])
            self.caption_projection = TEAdapterCaptionProjection(
                caption_channels=self.token_size,
                adapter=self,
            )
        else:
            sd.unet.set_attn_processor(attn_procs)
            self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())