mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
WIP to add the caption_proj weight to pixart sigma TE adapter
This commit is contained in:
@@ -2,15 +2,20 @@ import os
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
|
||||
model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
|
||||
te_path = "google/flan-t5-xl"
|
||||
te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
||||
# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
|
||||
# te_path = "google/flan-t5-xl"
|
||||
# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
||||
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS"
|
||||
te_path = "google/flan-t5-base"
|
||||
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS_t5base_raw"
|
||||
|
||||
|
||||
print("Loading te adapter")
|
||||
te_aug_sd = load_file(te_aug_path)
|
||||
@@ -18,10 +23,18 @@ te_aug_sd = load_file(te_aug_path)
|
||||
print("Loading model")
|
||||
is_diffusers = (not os.path.exists(model_path)) or os.path.isdir(model_path)
|
||||
|
||||
# if "pixart" in model_path.lower():
|
||||
is_pixart = "pixart" in model_path.lower()
|
||||
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
|
||||
if is_pixart:
|
||||
pipeline_class = PixArtSigmaPipeline
|
||||
|
||||
if is_diffusers:
|
||||
sd = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
else:
|
||||
sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)
|
||||
sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
|
||||
|
||||
print("Loading Text Encoder")
|
||||
# Load the text encoder
|
||||
@@ -31,23 +44,49 @@ te = T5EncoderModel.from_pretrained(te_path, torch_dtype=torch.float16)
|
||||
sd.text_encoder = te
|
||||
sd.tokenizer = T5Tokenizer.from_pretrained(te_path)
|
||||
|
||||
unet_sd = sd.unet.state_dict()
|
||||
if is_pixart:
|
||||
unet = sd.transformer
|
||||
unet_sd = sd.transformer.state_dict()
|
||||
else:
|
||||
unet = sd.transformer
|
||||
unet_sd = sd.unet.state_dict()
|
||||
|
||||
weight_idx = 1
|
||||
|
||||
if is_pixart:
|
||||
weight_idx = 0
|
||||
else:
|
||||
weight_idx = 1
|
||||
|
||||
new_cross_attn_dim = None
|
||||
|
||||
print("Patching UNet")
|
||||
for name in sd.unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
|
||||
# count the num of params in state dict
|
||||
start_params = sum([v.numel() for v in unet_sd.values()])
|
||||
|
||||
print("Building")
|
||||
attn_processor_keys = []
|
||||
if is_pixart:
|
||||
transformer: Transformer2DModel = unet
|
||||
for i, module in transformer.transformer_blocks.named_children():
|
||||
attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
|
||||
# cross attention
|
||||
attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
|
||||
else:
|
||||
attn_processor_keys = list(unet.attn_processors.keys())
|
||||
|
||||
for name in attn_processor_keys:
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith(
|
||||
"attn1") else \
|
||||
unet.config['cross_attention_dim']
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = sd.unet.config['block_out_channels'][-1]
|
||||
hidden_size = unet.config['block_out_channels'][-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
|
||||
hidden_size = list(reversed(unet.config['block_out_channels']))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = sd.unet.config['block_out_channels'][block_id]
|
||||
hidden_size = unet.config['block_out_channels'][block_id]
|
||||
elif name.startswith("transformer"):
|
||||
hidden_size = unet.config['cross_attention_dim']
|
||||
else:
|
||||
# they didnt have this, but would lead to undefined below
|
||||
raise ValueError(f"unknown attn processor name: {name}")
|
||||
@@ -60,7 +99,10 @@ for name in sd.unet.attn_processors.keys():
|
||||
|
||||
te_aug_name = None
|
||||
while True:
|
||||
te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter"
|
||||
if is_pixart:
|
||||
te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter"
|
||||
else:
|
||||
te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter"
|
||||
if f"{te_aug_name}.weight" in te_aug_sd:
|
||||
# increment so we dont redo it next time
|
||||
weight_idx += 1
|
||||
@@ -86,7 +128,10 @@ sd.save_pretrained(
|
||||
)
|
||||
|
||||
# overwrite the unet
|
||||
unet_folder = os.path.join(output_path, "unet")
|
||||
if is_pixart:
|
||||
unet_folder = os.path.join(output_path, "transformer")
|
||||
else:
|
||||
unet_folder = os.path.join(output_path, "unet")
|
||||
|
||||
# move state_dict to cpu
|
||||
unet_sd = {k: v.clone().cpu().to(torch.float16) for k, v in unet_sd.items()}
|
||||
@@ -94,7 +139,7 @@ unet_sd = {k: v.clone().cpu().to(torch.float16) for k, v in unet_sd.items()}
|
||||
meta = OrderedDict()
|
||||
meta["format"] = "pt"
|
||||
|
||||
print("Patching new unet")
|
||||
print("Patching")
|
||||
|
||||
save_file(unet_sd, os.path.join(unet_folder, "diffusion_pytorch_model.safetensors"), meta)
|
||||
|
||||
@@ -104,8 +149,17 @@ with open(os.path.join(unet_folder, "config.json"), 'r') as f:
|
||||
|
||||
config['cross_attention_dim'] = new_cross_attn_dim
|
||||
|
||||
if is_pixart:
|
||||
config['caption_channels'] = te.config.d_model
|
||||
|
||||
# save it
|
||||
with open(os.path.join(unet_folder, "config.json"), 'w') as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
print("Done")
|
||||
|
||||
new_params = sum([v.numel() for v in unet_sd.values()])
|
||||
|
||||
# print new and old params with , formatted
|
||||
print(f"Old params: {start_params:,}")
|
||||
print(f"New params: {new_params:,}")
|
||||
|
||||
76
testing/shrink_pixart.py
Normal file
76
testing/shrink_pixart.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
|
||||
# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
|
||||
# te_path = "google/flan-t5-xl"
|
||||
# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
||||
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_tiny.safetensors"
|
||||
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
|
||||
|
||||
state_dict = load_file(model_path)
|
||||
|
||||
meta = OrderedDict()
|
||||
meta["format"] = "pt"
|
||||
|
||||
# has 28 blocks
|
||||
# keep block 0 and 27
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# move non blocks over
|
||||
for key, value in state_dict.items():
|
||||
if not key.startswith("transformer_blocks."):
|
||||
new_state_dict[key] = value
|
||||
|
||||
block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
|
||||
'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
|
||||
'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
|
||||
'transformer_blocks.{idx}.scale_shift_table']
|
||||
|
||||
# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27
|
||||
|
||||
current_idx = 0
|
||||
for i in range(28):
|
||||
if i not in [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]:
|
||||
# todo merge in with previous block
|
||||
for name in block_names:
|
||||
try:
|
||||
new_state_dict_key = name.format(idx=current_idx - 1)
|
||||
old_state_dict_key = name.format(idx=i)
|
||||
new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5)
|
||||
except KeyError:
|
||||
raise KeyError(f"KeyError: {name.format(idx=current_idx)}")
|
||||
else:
|
||||
for name in block_names:
|
||||
new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)]
|
||||
current_idx += 1
|
||||
|
||||
|
||||
# make sure they are all fp16 and on cpu
|
||||
for key, value in new_state_dict.items():
|
||||
new_state_dict[key] = value.to(torch.float16).cpu()
|
||||
|
||||
# save the new state dict
|
||||
save_file(new_state_dict, output_path, metadata=meta)
|
||||
|
||||
new_param_count = sum([v.numel() for v in new_state_dict.values()])
|
||||
old_param_count = sum([v.numel() for v in state_dict.values()])
|
||||
|
||||
# porint comma formatted
|
||||
print(f"Old param count: {old_param_count:,}")
|
||||
print(f"New param count: {new_param_count:,}")
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import weakref
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
@@ -59,6 +60,7 @@ class LoRAModule(ToolkitModuleMixin, ExtractableModuleMixin, torch.nn.Module):
|
||||
ToolkitModuleMixin.__init__(self, network=network)
|
||||
torch.nn.Module.__init__(self)
|
||||
self.lora_name = lora_name
|
||||
self.orig_module_ref = weakref.ref(org_module)
|
||||
self.scalar = torch.tensor(1.0)
|
||||
# check if parent has bias. if not force use_bias to False
|
||||
if org_module.bias is None:
|
||||
|
||||
267
toolkit/models/LoRAFormer.py
Normal file
267
toolkit/models/LoRAFormer.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import math
|
||||
import weakref
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, List, Dict, Any
|
||||
from toolkit.models.clip_fusion import ZipperBlock
|
||||
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
|
||||
import sys
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from ipadapter.ip_adapter.resampler import Resampler
|
||||
from collections import OrderedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.lora_special import LoRAModule
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, d_model, nhead, dim_feedforward):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
|
||||
self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
nn.ReLU(),
|
||||
nn.Linear(dim_feedforward, d_model)
|
||||
)
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, x, cross_attn_input):
|
||||
# Self-attention
|
||||
attn_output, _ = self.self_attn(x, x, x)
|
||||
x = self.norm1(x + attn_output)
|
||||
|
||||
# Cross-attention
|
||||
cross_attn_output, _ = self.cross_attn(x, cross_attn_input, cross_attn_input)
|
||||
x = self.norm2(x + cross_attn_output)
|
||||
|
||||
# Feed-forward
|
||||
ff_output = self.feed_forward(x)
|
||||
x = self.norm3(x + ff_output)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class InstantLoRAMidModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
index: int,
|
||||
lora_module: 'LoRAModule',
|
||||
instant_lora_module: 'InstantLoRAModule',
|
||||
up_shape: list = None,
|
||||
down_shape: list = None,
|
||||
):
|
||||
super(InstantLoRAMidModule, self).__init__()
|
||||
self.up_shape = up_shape
|
||||
self.down_shape = down_shape
|
||||
self.index = index
|
||||
self.lora_module_ref = weakref.ref(lora_module)
|
||||
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
|
||||
|
||||
self.embed = None
|
||||
|
||||
def down_forward(self, x, *args, **kwargs):
|
||||
# get the embed
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
down_size = math.prod(self.down_shape)
|
||||
down_weight = self.embed[:, :down_size]
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# unconditional
|
||||
if down_weight.shape[0] * 2 == batch_size:
|
||||
down_weight = torch.cat([down_weight] * 2, dim=0)
|
||||
|
||||
weight_chunks = torch.chunk(down_weight, batch_size, dim=0)
|
||||
x_chunks = torch.chunk(x, batch_size, dim=0)
|
||||
|
||||
x_out = []
|
||||
for i in range(batch_size):
|
||||
weight_chunk = weight_chunks[i]
|
||||
x_chunk = x_chunks[i]
|
||||
# reshape
|
||||
weight_chunk = weight_chunk.view(self.down_shape)
|
||||
# check if is conv or linear
|
||||
if len(weight_chunk.shape) == 4:
|
||||
padding = 0
|
||||
if weight_chunk.shape[-1] == 3:
|
||||
padding = 1
|
||||
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
|
||||
else:
|
||||
# run a simple linear layer with the down weight
|
||||
x_chunk = x_chunk @ weight_chunk.T
|
||||
x_out.append(x_chunk)
|
||||
x = torch.cat(x_out, dim=0)
|
||||
return x
|
||||
|
||||
|
||||
def up_forward(self, x, *args, **kwargs):
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
up_size = math.prod(self.up_shape)
|
||||
up_weight = self.embed[:, -up_size:]
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# unconditional
|
||||
if up_weight.shape[0] * 2 == batch_size:
|
||||
up_weight = torch.cat([up_weight] * 2, dim=0)
|
||||
|
||||
weight_chunks = torch.chunk(up_weight, batch_size, dim=0)
|
||||
x_chunks = torch.chunk(x, batch_size, dim=0)
|
||||
|
||||
x_out = []
|
||||
for i in range(batch_size):
|
||||
weight_chunk = weight_chunks[i]
|
||||
x_chunk = x_chunks[i]
|
||||
# reshape
|
||||
weight_chunk = weight_chunk.view(self.up_shape)
|
||||
# check if is conv or linear
|
||||
if len(weight_chunk.shape) == 4:
|
||||
padding = 0
|
||||
if weight_chunk.shape[-1] == 3:
|
||||
padding = 1
|
||||
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
|
||||
else:
|
||||
# run a simple linear layer with the down weight
|
||||
x_chunk = x_chunk @ weight_chunk.T
|
||||
x_out.append(x_chunk)
|
||||
x = torch.cat(x_out, dim=0)
|
||||
return x
|
||||
|
||||
|
||||
# Initialize the network
|
||||
# num_blocks = 8
|
||||
# d_model = 1024 # Adjust as needed
|
||||
# nhead = 16 # Adjust as needed
|
||||
# dim_feedforward = 4096 # Adjust as needed
|
||||
# latent_dim = 1695744
|
||||
|
||||
class LoRAFormer(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks,
|
||||
d_model=1024,
|
||||
nhead=16,
|
||||
dim_feedforward=4096,
|
||||
sd: 'StableDiffusion'=None,
|
||||
):
|
||||
super(LoRAFormer, self).__init__()
|
||||
# self.linear = torch.nn.Linear(2, 1)
|
||||
self.sd_ref = weakref.ref(sd)
|
||||
self.dim = sd.network.lora_dim
|
||||
|
||||
# stores the projection vector. Grabbed by modules
|
||||
self.img_embeds: List[torch.Tensor] = None
|
||||
|
||||
# disable merging in. It is slower on inference
|
||||
self.sd_ref().network.can_merge_in = False
|
||||
|
||||
self.ilora_modules = torch.nn.ModuleList()
|
||||
|
||||
lora_modules = self.sd_ref().network.get_all_modules()
|
||||
|
||||
output_size = 0
|
||||
|
||||
self.embed_lengths = []
|
||||
self.weight_mapping = []
|
||||
|
||||
for idx, lora_module in enumerate(lora_modules):
|
||||
module_dict = lora_module.state_dict()
|
||||
down_shape = list(module_dict['lora_down.weight'].shape)
|
||||
up_shape = list(module_dict['lora_up.weight'].shape)
|
||||
|
||||
self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]])
|
||||
|
||||
module_size = math.prod(down_shape) + math.prod(up_shape)
|
||||
output_size += module_size
|
||||
self.embed_lengths.append(module_size)
|
||||
|
||||
|
||||
# add a new mid module that will take the original forward and add a vector to it
|
||||
# this will be used to add the vector to the original forward
|
||||
instant_module = InstantLoRAMidModule(
|
||||
idx,
|
||||
lora_module,
|
||||
self,
|
||||
up_shape=up_shape,
|
||||
down_shape=down_shape
|
||||
)
|
||||
|
||||
self.ilora_modules.append(instant_module)
|
||||
|
||||
# replace the LoRA forwards
|
||||
lora_module.lora_down.forward = instant_module.down_forward
|
||||
lora_module.lora_up.forward = instant_module.up_forward
|
||||
|
||||
|
||||
self.output_size = output_size
|
||||
|
||||
self.latent = nn.Parameter(torch.randn(1, output_size))
|
||||
self.latent_proj = nn.Linear(output_size, d_model)
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerBlock(d_model, nhead, dim_feedforward)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.final_proj = nn.Linear(d_model, output_size)
|
||||
|
||||
self.migrate_weight_mapping()
|
||||
|
||||
def migrate_weight_mapping(self):
|
||||
return
|
||||
# # changes the names of the modules to common ones
|
||||
# keymap = self.sd_ref().network.get_keymap()
|
||||
# save_keymap = {}
|
||||
# if keymap is not None:
|
||||
# for ldm_key, diffusers_key in keymap.items():
|
||||
# # invert them
|
||||
# save_keymap[diffusers_key] = ldm_key
|
||||
#
|
||||
# new_keymap = {}
|
||||
# for key, value in self.weight_mapping:
|
||||
# if key in save_keymap:
|
||||
# new_keymap[save_keymap[key]] = value
|
||||
# else:
|
||||
# print(f"Key {key} not found in keymap")
|
||||
# new_keymap[key] = value
|
||||
# self.weight_mapping = new_keymap
|
||||
# else:
|
||||
# print("No keymap found. Using default names")
|
||||
# return
|
||||
|
||||
|
||||
def forward(self, img_embeds):
|
||||
# expand token rank if only rank 2
|
||||
if len(img_embeds.shape) == 2:
|
||||
img_embeds = img_embeds.unsqueeze(1)
|
||||
|
||||
# resample the image embeddings
|
||||
img_embeds = self.resampler(img_embeds)
|
||||
img_embeds = self.proj_module(img_embeds)
|
||||
if len(img_embeds.shape) == 3:
|
||||
# merge the heads
|
||||
img_embeds = img_embeds.mean(dim=1)
|
||||
|
||||
self.img_embeds = []
|
||||
# get all the slices
|
||||
start = 0
|
||||
for length in self.embed_lengths:
|
||||
self.img_embeds.append(img_embeds[:, start:start+length])
|
||||
start += length
|
||||
|
||||
|
||||
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||
# save the weight mapping
|
||||
return {
|
||||
"weight_mapping": self.weight_mapping,
|
||||
"num_heads": self.num_heads,
|
||||
"vision_hidden_size": self.vision_hidden_size,
|
||||
"head_dim": self.head_dim,
|
||||
"vision_tokens": self.vision_tokens,
|
||||
"output_size": self.output_size,
|
||||
}
|
||||
|
||||
@@ -156,10 +156,10 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
weight_chunk = weight_chunk.view(self.down_shape)
|
||||
# check if is conv or linear
|
||||
if len(weight_chunk.shape) == 4:
|
||||
padding = 0
|
||||
if weight_chunk.shape[-1] == 3:
|
||||
padding = 1
|
||||
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding)
|
||||
org_module = self.lora_module_ref().orig_module_ref()
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding, stride=stride)
|
||||
else:
|
||||
# run a simple linear layer with the down weight
|
||||
x_chunk = x_chunk @ weight_chunk.T
|
||||
|
||||
@@ -6,7 +6,9 @@ import torch.nn.functional as F
|
||||
import weakref
|
||||
from typing import Union, TYPE_CHECKING
|
||||
|
||||
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPTextModelWithProjection
|
||||
from diffusers.models.embeddings import PixArtAlphaTextProjection
|
||||
|
||||
from toolkit import train_tools
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
@@ -17,11 +19,71 @@ sys.path.append(REPOS_ROOT)
|
||||
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
from toolkit.custom_adapter import CustomAdapter
|
||||
|
||||
|
||||
class TEAdapterCaptionProjection(nn.Module):
|
||||
def __init__(self, caption_channels, adapter: 'TEAdapter'):
|
||||
super().__init__()
|
||||
in_features = caption_channels
|
||||
self.adapter_ref: weakref.ref = weakref.ref(adapter)
|
||||
sd = adapter.sd_ref()
|
||||
self.parent_module_ref = weakref.ref(sd.transformer.caption_projection)
|
||||
parent_module = self.parent_module_ref()
|
||||
self.linear_1 = nn.Linear(
|
||||
in_features=in_features,
|
||||
out_features=parent_module.linear_1.out_features,
|
||||
bias=True
|
||||
)
|
||||
self.linear_2 = nn.Linear(
|
||||
in_features=parent_module.linear_2.in_features,
|
||||
out_features=parent_module.linear_2.out_features,
|
||||
bias=True
|
||||
)
|
||||
|
||||
# save the orig forward
|
||||
parent_module.linear_1.orig_forward = parent_module.linear_1.forward
|
||||
parent_module.linear_2.orig_forward = parent_module.linear_2.forward
|
||||
|
||||
# replace original forward
|
||||
parent_module.orig_forward = parent_module.forward
|
||||
parent_module.forward = self.forward
|
||||
|
||||
|
||||
@property
|
||||
def is_active(self):
|
||||
return self.adapter_ref().is_active
|
||||
|
||||
@property
|
||||
def unconditional_embeds(self):
|
||||
return self.adapter_ref().adapter_ref().unconditional_embeds
|
||||
|
||||
@property
|
||||
def conditional_embeds(self):
|
||||
return self.adapter_ref().adapter_ref().conditional_embeds
|
||||
|
||||
def forward(self, caption):
|
||||
if self.is_active and self.conditional_embeds is not None:
|
||||
adapter_hidden_states = self.conditional_embeds.text_embeds
|
||||
# check if we are doing unconditional
|
||||
if self.unconditional_embeds is not None and adapter_hidden_states.shape[0] != caption.shape[0]:
|
||||
# concat unconditional to match the hidden state batch size
|
||||
if self.unconditional_embeds.text_embeds.shape[0] == 1 and adapter_hidden_states.shape[0] != 1:
|
||||
unconditional = torch.cat([self.unconditional_embeds.text_embeds] * adapter_hidden_states.shape[0], dim=0)
|
||||
else:
|
||||
unconditional = self.unconditional_embeds.text_embeds
|
||||
adapter_hidden_states = torch.cat([unconditional, adapter_hidden_states], dim=0)
|
||||
hidden_states = self.linear_1(adapter_hidden_states)
|
||||
hidden_states = self.parent_module_ref().act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
else:
|
||||
return self.parent_module_ref().orig_forward(caption)
|
||||
|
||||
|
||||
class TEAdapterAttnProcessor(nn.Module):
|
||||
r"""
|
||||
Attention processor for Custom TE for PyTorch 2.0.
|
||||
@@ -177,6 +239,8 @@ class TEAdapter(torch.nn.Module):
|
||||
self.te_ref: weakref.ref = weakref.ref(te)
|
||||
self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)
|
||||
self.adapter_modules = []
|
||||
self.caption_projection = None
|
||||
self.embeds_store = []
|
||||
is_pixart = sd.is_pixart
|
||||
|
||||
if self.adapter_ref().config.text_encoder_arch == "t5":
|
||||
@@ -297,6 +361,11 @@ class TEAdapter(torch.nn.Module):
|
||||
transformer.transformer_blocks[i].attn2.processor for i in
|
||||
range(len(transformer.transformer_blocks))
|
||||
])
|
||||
self.caption_projection = TEAdapterCaptionProjection(
|
||||
caption_channels=self.token_size,
|
||||
adapter=self,
|
||||
)
|
||||
|
||||
else:
|
||||
sd.unet.set_attn_processor(attn_procs)
|
||||
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
|
||||
|
||||
@@ -289,7 +289,13 @@ class ToolkitModuleMixin:
|
||||
scaled_lora_weight = lora_weight * scale
|
||||
scaled_lora_output = scaled_lora_output + self.apply_dora(lx, scaled_lora_weight)
|
||||
|
||||
x = org_forwarded + scaled_lora_output
|
||||
try:
|
||||
x = org_forwarded + scaled_lora_output
|
||||
except RuntimeError as e:
|
||||
print(e)
|
||||
print(org_forwarded.size())
|
||||
print(scaled_lora_output.size())
|
||||
raise e
|
||||
return x
|
||||
|
||||
def enable_gradient_checkpointing(self: Module):
|
||||
|
||||
@@ -309,6 +309,9 @@ class StableDiffusion:
|
||||
main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
|
||||
if self.model_config.is_pixart_sigma:
|
||||
main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"
|
||||
|
||||
main_model_path = model_path
|
||||
|
||||
# load the TE in 8bit mode
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
main_model_path,
|
||||
|
||||
Reference in New Issue
Block a user